In [1]:
import ast
import re
from collections.abc import Callable
from typing import Optional, List

In [2]:
from openai import OpenAI
from google.colab import userdata

client = OpenAI(
    api_key=userdata.get('OPENAI_API_KEY'),  # read from Colab secrets rather than hard-coding the key
)

In [3]:
def get_thought_gen_prompt(input_seq: str, state: str) -> str:
    """Get thought generation prompt.

    Keyword arguments:
    input_seq -- the input sequence (comprising four numbers, e.g., '1 1 1 8')
    state -- concatenation of all the thoughts so far (separated by newlines)

    Returns a "propose" prompt (few-shot list of possible next steps for the
    numbers still remaining) while the last thought has not left just 24, and
    an "answer" prompt (few-shot worked solutions followed by `state`) once it has.
    An empty or whitespace-only `state` is treated as the root node."""

    # Reference: https://github.com/princeton-nlp/tree-of-thought-llm/blob/master/src/tot/tasks/game24.py
    def get_remaining_numbers(thought: str) -> str:
        # Text after the last 'left: ' up to the next ')', e.g.
        # '2 + 8 = 10 (left: 8 10 14)' -> '8 10 14'. Stripped so stray spaces
        # cannot defeat the comparison with '24' below.
        return thought.split('left: ')[-1].split(')')[0].strip()

    if not state.strip(): # Root node; no thoughts have been generated yet (empty or whitespace-only state).
        remaining_numbers = input_seq.strip()
    else:
        last_thought = state.strip().split('\n')[-1]
        remaining_numbers = get_remaining_numbers(last_thought)

    if remaining_numbers != '24': # Intermediate step.
        # Reference: https://github.com/princeton-nlp/tree-of-thought-llm/blob/master/src/tot/prompts/game24.py
        prompt = f'''Input: 2 8 8 14
Possible next steps:
2 + 8 = 10 (left: 8 10 14)
8 / 2 = 4 (left: 4 8 14)
14 + 2 = 16 (left: 8 8 16)
2 * 8 = 16 (left: 8 14 16)
8 - 2 = 6 (left: 6 8 14)
14 - 8 = 6 (left: 2 6 8)
14 /  2 = 7 (left: 7 8 8)
14 - 2 = 12 (left: 8 8 12)
Input: {remaining_numbers}
Possible next steps:
'''
    else: # Last (output generation) step.
        # Reference: https://github.com/princeton-nlp/tree-of-thought-llm/blob/master/src/tot/prompts/game24.py
        prompt = f'''Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.
Input: 4 4 6 8
Steps:
4 + 8 = 12 (left: 4 6 12)
6 - 4 = 2 (left: 2 12)
2 * 12 = 24 (left: 24)
Answer: (6 - 4) * (4 + 8) = 24
Input: 2 9 10 12
Steps:
12 * 2 = 24 (left: 9 10 24)
10 - 9 = 1 (left: 1 24)
24 * 1 = 24 (left: 24)
Answer: (12 * 2) * (10 - 9) = 24
Input: 4 9 10 13
Steps:
13 - 10 = 3 (left: 3 4 9)
9 - 3 = 6 (left: 4 6)
4 * 6 = 24 (left: 24)
Answer: 4 * (9 - (13 - 10)) = 24
Input: 1 4 8 8
Steps:
8 / 4 = 2 (left: 1 2 8)
1 + 2 = 3 (left: 3 8)
3 * 8 = 24 (left: 24)
Answer: (1 + 8 / 4) * 8 = 24
Input: 5 5 5 9
Steps:
5 + 5 = 10 (left: 5 9 10)
10 + 5 = 15 (left: 9 15)
15 + 9 = 24 (left: 24)
Answer: ((5 + 5) + 5) + 9 = 24
Input: {input_seq}
Steps:
{state}
'''
    return prompt

In [4]:
def get_state_eval_prompt(input_seq: str, state: str) -> str:
    """Build the prompt used to score a search state.

    Keyword arguments:
    input_seq -- the input sequence (comprising four numbers, e.g., '1 1 1 8')
    state -- concatenation of all the thoughts so far (separated by newlines)

    Only the last line of the stripped `state` is inspected. If it contains
    'left: ', the numbers remaining after that step are to be rated
    sure/likely/impossible. Otherwise the line is treated as a final answer
    (lowercased, with any 'answer: ' label removed) and a sure/impossible
    judgement prompt is returned."""

    final_line = state.strip().split('\n')[-1]

    if 'left: ' in final_line: # Intermediate step.
        # Same extraction as the reference implementation:
        # https://github.com/princeton-nlp/tree-of-thought-llm/blob/master/src/tot/tasks/game24.py
        numbers_left = final_line.split('left: ')[-1].split(')')[0]
        # Reference: https://github.com/princeton-nlp/tree-of-thought-llm/blob/master/src/tot/prompts/game24.py
        return f'''Evaluate if given numbers can reach 24 (sure/likely/impossible)
10 14
10 + 14 = 24
sure
11 12
11 + 12 = 23
12 - 11 = 1
11 * 12 = 132
11 / 12 = 0.91
impossible
4 4 10
4 + 4 + 10 = 8 + 10 = 18
4 * 10 - 4 = 40 - 4 = 36
(10 - 4) * 4 = 6 * 4 = 24
sure
4 9 11
9 + 11 + 4 = 20 + 4 = 24
sure
5 7 8
5 + 7 + 8 = 12 + 8 = 20
(8 - 5) * 7 = 3 * 7 = 21
I cannot obtain 24 now, but numbers are within a reasonable range
likely
5 6 6
5 + 6 + 6 = 17
(6 - 5) * 6 = 1 * 6 = 6
I cannot obtain 24 now, but numbers are within a reasonable range
likely
10 10 11
10 + 10 + 11 = 31
(11 - 10) * 10 = 10
10 10 11 are all too big
impossible
1 3 3
1 * 3 * 3 = 9
(1 + 3) * 3 = 12
1 3 3 are all too small
impossible
{numbers_left}
'''

    # Last (output generation) step: drop the label so only the expression remains.
    answer = final_line.lower().replace('answer: ', '')
    # Reference: https://github.com/princeton-nlp/tree-of-thought-llm/blob/master/src/tot/prompts/game24.py
    return f'''Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Given an input and an answer, give a judgement (sure/impossible) if the answer is correct, i.e. it uses each input exactly once and no other numbers, and reach 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Judge:
sure
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Judge:
sure
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Judge:
sure
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) + 1 = 25
Judge:
impossible
Input: 2 9 10 12
Answer: 2 * (12 - 10) = 24
Judge:
impossible
Input: 4 9 10 13
Answer: (13 - 4) * (10 - 9) = 24
Judge:
impossible
Input: {input_seq}
Answer: {answer}
Judge:'''

In [5]:
# Reference: https://github.com/princeton-nlp/tree-of-thought-llm/blob/master/src/tot/tasks/game24.py
def heuristic_calculator(state: str, state_evals: List[Optional[str]]) -> float:
    """Aggregate several LLM state evaluations into one heuristic value.

    Keyword arguments:
    state -- concatenation of all the thoughts so far (separated by newlines)
    state_evals -- raw LLM outputs for the state evaluation prompt; the verdict
                   ('impossible' / 'likely' / 'sure') is read from the last line
                   of each whitespace-stripped, lowercased output (None counts as empty)

    Returns 0.0 when the stripped state has exactly four lines and no 'answer'
    (case-insensitive); otherwise the weighted verdict count with weights
    impossible=0.001, likely=1, sure=20. Unrecognised verdicts add nothing.
    """
    if len(state.strip().split('\n')) == 4 and 'answer' not in state.lower(): # Such a state is undesirable.
        return 0.0
    # Strip before and after taking the last line: outputs such as 'sure\n' (whose
    # last split line is '') or ' likely' would otherwise never match a verdict.
    value_names = [(output or '').strip().split('\n')[-1].strip().lower() for output in state_evals]
    value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20} # Ad hoc.
    value = sum(value * value_names.count(name) for name, value in value_map.items())
    return float(value)

In [6]:
class FewShotSearchGameOf24DFS:
    """Let a chat model role-play a DFS tree search for the Game of 24.

    The model (`model`) is conditioned on few-shot search traces and emits
    `Action: <tool>(...)` lines; `react_loop` executes the requested tool and
    feeds its result back as the next user message until the model reports
    "We're now outside the tree."
    """

    # Tool methods the model may call from an `Action:` line.
    _TOOL_NAMES = ('thought_generator', 'state_evaluator', 'backtrack')

    def __init__(
            self,
            client: "OpenAI",
            model: str,
            messages: List[dict],
            input_seq: str,
            get_thought_gen_prompt: Callable,
            get_state_eval_prompt: Callable,
            heuristic_calculator: Callable,
            n_evals: int = 3,
            tool_model: str = "gpt-4"
    ):
        """
        Keyword arguments:
        client -- OpenAI client used for every chat completion
        model -- model that role-plays the search algorithm
        messages -- few-shot conversation; copied, so the caller's list is never modified
        input_seq -- new user message (e.g. 'Four numbers: 1 1 1 8' followed by the search parameters)
        get_thought_gen_prompt -- builds the thought generation prompt from (numbers, state)
        get_state_eval_prompt -- builds the state evaluation prompt from (numbers, state)
        heuristic_calculator -- maps (state, list of raw evaluations) to a value
        n_evals -- number of independent evaluations per state
        tool_model -- model used inside the thought generator / state evaluator tools
        """
        self.client = client
        self.model = model
        # Copy: several searches are typically built from the same few-shot list, and
        # appending to it would leak one search's trace into the next one's prompt.
        self.messages = list(messages)
        self.input_seq = input_seq
        # The prompt builders expect only the numbers (e.g. '1 1 1 8'), not the whole
        # user message that also carries the search parameters.
        self.puzzle_numbers = self._extract_numbers(input_seq)
        self.messages.append({'role': "user", 'content': input_seq})
        self.get_thought_gen_prompt = get_thought_gen_prompt # For consumption of the thought generator.
        self.get_state_eval_prompt = get_state_eval_prompt # For consumption of the state evaluator.
        self.heuristic_calculator = heuristic_calculator # For consumption of the state evaluator.
        self.n_evals = n_evals # For consumption of the state evaluator.
        self.tool_model = tool_model # For consumption of both tools.

    @staticmethod
    def _extract_numbers(input_seq: str) -> str:
        """Return the text after 'numbers:' on its line, or `input_seq` unchanged if there is none."""
        match = re.search(r'numbers:[ \t]*([^\n]*)', input_seq, flags=re.IGNORECASE)
        if match and match.group(1).strip():
            return match.group(1).strip()
        return input_seq

    def chat_completion(
            self,
            messages: List[dict],
            model: str,
            temperature: float = 0.2,
            max_tokens: int = 4096,
            n: int = 1,
            stop: str = "\nObservation:",
            **kwargs
    ) -> str:
        """Request a chat completion and return the text of the first choice ('' if it has none)."""
        response = self.client.chat.completions.create(
            messages=messages,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            n=n,
            stop=stop,
            **kwargs
        )
        # Only the first choice is used even when n > 1. The message content is
        # nullable in the API, so normalise None to '' for the string checks in react_loop.
        return response.choices[0].message.content or ''

    def thought_generator(self, state: str) -> str: # Tool.
        """Return newline-separated candidate next steps for `state`."""
        prompt = self.get_thought_gen_prompt(self.puzzle_numbers, state)
        # The text generation parameters are the same as: https://github.com/princeton-nlp/tree-of-thought-llm/blob/master/src/tot/models.py
        thoughts = self.chat_completion([{'role': "user", 'content': prompt}], model=self.tool_model, temperature=0.7, max_tokens=1000, n=1, stop=None)
        # Note: Due to the template of the thought generation prompt, each thought will appear on a new line.
        return thoughts

    def state_evaluator(self, state: str) -> float: # Tool.
        """Score `state` by sampling `n_evals` evaluations and combining them with the heuristic."""
        prompt = self.get_state_eval_prompt(self.puzzle_numbers, state)
        state_evals = []
        for _ in range(self.n_evals):
            # The text generation parameters are the same as: https://github.com/princeton-nlp/tree-of-thought-llm/blob/master/src/tot/models.py
            state_eval = self.chat_completion([{'role': "user", 'content': prompt}], model=self.tool_model, temperature=0.7, max_tokens=1000, n=1, stop=None)
            state_evals.append(state_eval)
        value = self.heuristic_calculator(state, state_evals)
        return value

    def backtrack(self, from_node: str, to_node: str, flag: bool) -> str:
        """A dummy tool that doesn't do anything in this notebook.
        However, in the evaluation notebook, the `backtrack` method is used to perform various 'correctness checks' related to backtracking."""

        return "Ok."

    def _run_tool_call(self, tool_call_string: str):
        """Execute a tool call such as "thought_generator('')" and return its result.

        The string is model output, so it is never passed to eval(). It is parsed
        with `ast`; only a direct call to a whitelisted tool whose arguments are
        literals is executed. Anything else raises AssertionError("Invalid tool call string!").
        """
        try:
            call = ast.parse(tool_call_string, mode='eval').body
        except (SyntaxError, ValueError) as exc:
            raise AssertionError("Invalid tool call string!") from exc
        if not (isinstance(call, ast.Call)
                and isinstance(call.func, ast.Name)
                and call.func.id in self._TOOL_NAMES
                and all(kw.arg is not None for kw in call.keywords)):  # reject **kwargs splats
            raise AssertionError("Invalid tool call string!")
        try:
            args = [ast.literal_eval(arg) for arg in call.args]
            kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
        except (ValueError, TypeError, SyntaxError) as exc:
            raise AssertionError("Invalid tool call string!") from exc
        return getattr(self, call.func.id)(*args, **kwargs)

    def react_loop(self, step_limit: Optional[int] = None) -> Optional[str]:
        """Alternate model turns and tool calls until the model leaves the tree.

        Returns '-- end of search --' when a completion contains
        "We're now outside the tree.", or None if `step_limit` tool steps were
        taken first. Raises AssertionError when a completion's `Action:` line is
        not a valid tool call.
        """
        i = 0
        while True:
            if step_limit is not None and i >= step_limit:
                print("Step limit reached! Breaking the ReAct loop...")
                return None
            completion = self.chat_completion(self.messages, model=self.model)
            print(completion)
            print("---")
            if "We're now outside the tree." in completion:
                self.messages.append({'role': "assistant", 'content': completion})
                return "-- end of search --"
            # Otherwise the completion ends with a tool call ("Action: name(args)").
            tool_call_string = completion.split('Action: ')[-1].strip()
            result = self._run_tool_call(tool_call_string)
            print(result)
            print("---")
            completion += "\nObservation:\n"
            self.messages.append({'role': "assistant", 'content': completion})
            self.messages.append({'role': "user", 'content': str(result)})
            i += 1

Let's fetch the few-shot examples.

In [7]:
# Clone only when the repository is not already present, so re-running this cell does not fail.
!test -d few-shot-search || git clone https://github.com/sambitmukherjee/few-shot-search.git

fatal: destination path 'few-shot-search' already exists and is not an empty directory.


In [8]:
# Work from inside the cloned repo so the `few_shot_examples` package (presumably shipped in it) can be imported next.
# NOTE(review): relative path — re-running this cell from inside the repo will not find the directory.
%cd few-shot-search

/content/few-shot-search


In [9]:
from few_shot_examples.game_of_24_dfs_4_shot_examples import *

In [10]:
# Number of messages in the imported few-shot transcript.
len(messages)

210

In [11]:
# First few-shot message: the user turn carrying the puzzle and the search parameters.
messages[0]

{'role': 'user',
 'content': 'Four numbers: 1 1 9 9\n\nSearch parameters:-\nHeuristic threshold: 3.0\nLevel limit: 3'}

In [12]:
# Last few-shot message: the assistant turn that closes the final example's search.
messages[-1]

{'role': 'assistant',
 'content': "---\n~~~\nWe're now outside the tree.\n\nFlag received: True"}

In [13]:
# Instructions that are prepended to the few-shot transcript as the system message.
system_prompt = "\n".join([
    "Your task is to role-play a particular type of tree search algorithm.",
    "",
    "To help you mimic this algorithm successfully, you will be provided few-shot examples. Study these examples carefully to learn how to mimic the algorithm successfully.",
    "You must adhere to the reasoning style and choice of words in these few-shot examples. Do not use any new words that aren't present in the few-shot examples.",
    "",
    "If at any point, you think that you have found the final solution (i.e., the search is over), you must still backtrack all the way up the tree (as shown in the few-shot examples). Do not fail to do this.",
])

In [14]:
# Prepend the system prompt exactly once. Without the guard, every re-run of this
# cell would stack another copy of the system message onto `messages`.
system_message = {'role': "system", 'content': system_prompt}
if not messages or messages[0] != system_message:
    messages = [system_message] + messages

First, let's check the search trace for four numbers that cannot reach 24. (Feel free to replace the puzzle below with a different set of four numbers that cannot reach 24.)

In [15]:
# User message for a puzzle intended to have no solution: the four numbers followed by
# the search parameters, in the same format as the few-shot examples' first message.
new_input_seq1 = (
    'Four numbers: 1 8 12 13\n'
    '\n'
    'Search parameters:-\n'
    'Heuristic threshold: 3.0\n'
    'Level limit: 3'
)

In [16]:
# Pass a copy of the few-shot messages so this search's turns can never leak into
# `messages`, which is reused for the second puzzle further down.
few_shot_search = FewShotSearchGameOf24DFS(
    client,
    "gpt-4o",
    list(messages),
    new_input_seq1,
    get_thought_gen_prompt,
    get_state_eval_prompt,
    heuristic_calculator,
)

In [17]:
# Run the search until the model reports it has left the tree. No step limit is passed,
# so the number of API calls is bounded only by the model's own behaviour.
few_shot_search.react_loop()

Starting the tree search...
---
~~~
Current level: 0
Index of current node: 0
State of current node:
'' (empty string; no thoughts have been generated yet as we're still at the root node)

Using the thought generator tool to obtain thought candidates...

Action: thought_generator('')

---
1 + 8 = 9 (left: 9 12 13)
8 - 1 = 7 (left: 7 12 13)
12 - 1 = 11 (left: 8 11 13)
13 - 1 = 12 (left: 8 12 12)
1 * 8 = 8 (left: 8 12 13)
12 / 1 = 12 (left: 8 12 13)
13 / 1 = 13 (left: 8 12 13)
12 + 1 = 13 (left: 8 13 13)
13 - 8 = 5 (left: 1 5 12)
13 - 12 = 1 (left: 1 1 8)
8 * 1 = 8 (left: 8 12 13)
12 - 8 = 4 (left: 1 4 13)
13 / 8 = 1.625 (left: 1 1.625 12)
12 * 1 = 12 (left: 8 12 13)
13 * 1 = 13 (left: 8 12 13)
---
Each of the above thought candidates has been added as a child node of the current node:

Node 0_0: 1 + 8 = 9 (left: 9 12 13)
Node 0_1: 8 - 1 = 7 (left: 7 12 13)
Node 0_2: 12 - 1 = 11 (left: 8 11 13)
Node 0_3: 13 - 1 = 12 (left: 8 12 12)
Node 0_4: 1 * 8 = 8 (left: 8 12 13)
Node 0_5: 12 / 1 = 1

'-- end of search --'

Now, let's check the search trace for four numbers that can reach 24. (Feel free to replace the puzzle below with a different puzzle from: https://github.com/sambitmukherjee/tree-of-thought-llm/blob/master/src/tot/data/24/24.csv)

In [18]:
# User message for a solvable puzzle (e.g. (13 + 11) * (11 / 11) = 24), followed by
# the search parameters in the same format as the few-shot examples.
new_input_seq2 = (
    'Four numbers: 11 11 11 13\n'
    '\n'
    'Search parameters:-\n'
    'Heuristic threshold: 3.0\n'
    'Level limit: 3'
)

In [19]:
# Pass a copy of the few-shot messages so this search starts from the clean few-shot
# transcript and its own turns are not appended to the shared `messages` list.
few_shot_search = FewShotSearchGameOf24DFS(
    client,
    "gpt-4o",
    list(messages),
    new_input_seq2,
    get_thought_gen_prompt,
    get_state_eval_prompt,
    heuristic_calculator,
)

In [20]:
# Run the search on the solvable puzzle. No step limit is passed, so the number of API
# calls is bounded only by the model's own behaviour.
few_shot_search.react_loop()

Starting the tree search...
---
~~~
Current level: 0
Index of current node: 0
State of current node:
'' (empty string; no thoughts have been generated yet as we're still at the root node)

Using the thought generator tool to obtain thought candidates...

Action: thought_generator('')

---
11 + 11 = 22 (left: 11 13 22)
11 * 11 = 121 (left: 11 13 121)
13 - 11 = 2 (left: 2 11 11)
11 / 11 = 1 (left: 1 11 13)
13 + 11 = 24 (left: 11 11 24)
13 * 11 = 143 (left: 11 11 143)
---
Each of the above thought candidates has been added as a child node of the current node:

Node 0_0: 11 + 11 = 22 (left: 11 13 22)
Node 0_1: 11 * 11 = 121 (left: 11 13 121)
Node 0_2: 13 - 11 = 2 (left: 2 11 11)
Node 0_3: 11 / 11 = 1 (left: 1 11 13)
Node 0_4: 13 + 11 = 24 (left: 11 11 24)
Node 0_5: 13 * 11 = 143 (left: 11 11 143)

Looping through the child nodes one at a time...
---
Reminder:-
Current level: 0
Index of current node: 0
State of current node:
'' (empty string; no thoughts have been generated yet as we're sti

'-- end of search --'