In [86]:
from operator import itemgetter
from typing import Any
from pydantic import BaseModel, Field
from openai import OpenAI
from collections import deque

# PARAMETERS
DEBUG = True
#--------------------------------------------------------------------------------
# Structure for the outputs

class Status(BaseModel):
    status: str = Field(description="The status of the chain of thoughts. Either we should continue to reach a final asnwer, we should terminate this chain because it won't reach a final answer and it is going stray. If a we are ready for a final answer (regardless of correctness) return ready", enum=["continue", "terminate", "ready"])

class Thought(BaseModel):
    """A thought on how to solve the problem"""
    thought: str = Field(description="The thought for the next step")

class Solution(BaseModel):
    """A solution to solve the problem"""
    solution: str = Field(description="The solution to the problem")

# OpenAI models low T, high T, and evaluator
class OpenAIParse(object):

    def __init__(self, model, response_format, system_prompt):
        self.client = OpenAI()
        self.model = model
        self.response_format= response_format
        self.system_prompt = system_prompt

    def __call__(self, prompt, temperature = 0):     
        completion = self.client.beta.chat.completions.parse(
            model=self.model,
            temperature=temperature,
            messages=[
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt},
            ],
            response_format=self.response_format,
            #logprobs=True,
            #top_logprobs=3,
        )
        return completion.choices[0].message.parsed

Thought_system_prompt = """Given the user query and the chain of thoughts, generate the next step (a thought) to solve the problem but do not generate a solution.
Also remember that you do not have access to any outside tools or sources of knowledge."""
Thought_user_prompt = "User query: {query}\nChain of thoughts: {chain_of_thoughts}"
Thought_model = OpenAIParse("gpt-4o-mini", Thought, Thought_system_prompt)
#--------------------------------------------------------------------------------
Status_system_prompt = """
Given the user query and the chain of thoughts, evaluate the chain of thoughts and determine the status of the chain of thoughts.
The chain of thoughts is a series of thoughts that are generated to solve a problem.
The chain of thoughts can be in one of three states: continue, terminate, ready."""
Status_user_prompt = "User query: {query}\nChain of thoughts: {chain_of_thoughts}"
Status_model = OpenAIParse("gpt-4o-mini", Status, Status_system_prompt)
#--------------------------------------------------------------------------------
Solution_system_prompt = """
Given the user query and the chain of thoughts, generate the solution to the problem.
The solution is the final answer to the problem."""
Solution_user_prompt = "User query: {query}\nChain of thoughts: {chain_of_thoughts}"
Solution_model = OpenAIParse("gpt-4o-mini", Solution, Solution_system_prompt)
#--------------------------------------------------------------------------------
Final_system_prompt = """
Given the user query and the set of solutions, evaluate the solutions and determine the best solution.
Note that they may be multiple similar solutions."""
Final_user_prompt = "User query: {query}\nSolutions: {solutions}"
Final_model = OpenAIParse("gpt-4o-mini", Solution, Final_system_prompt)

# Binary tree of thoughts

class Node(object):
    def __init__(self, thought, parent = None, status: Status = None):
        self.thought = thought
        self.thoughts = [thought]
        self.status = status
        self.children = []
        self.parent = parent
        self.depth = 0
        if parent:
            self.depth = parent.depth + 1
            self.parent = parent
            self.thoughts = parent.thoughts + [thought]

class BinNode(Node):
    def __init__(self, thought, parent = None, status: Status = None, low_temp = None, high_temp = None):
        super().__init__(thought, parent, status)
        self.low_temp = low_temp
        self.high_temp = high_temp

class Tree(object):
    def __init__(self, status_model, thought_model, solution_model, max_depth: int = 5, max_children: int = 2):
        self.root = None
        self.leaves = deque([self.root]) # list of leaves
        self.status_model = status_model
        self.solution_nodes = []
        self.solutions = []
        self.max_depth = max_depth
        self.max_children = max_children
        self.thought_model = thought_model
        self.solution_model = solution_model

    def set_status(self, query):
        for node in self.leaves:
            if not node.status:
                node.status = self.status_model(
                    Status_user_prompt.format(query=query, chain_of_thoughts='\n-'.join(node.thoughts))
                    )

    def prune(self):
        removal = []
        for node in self.leaves:
            if node.status.status == "terminate":
                removal.append(node)
            elif node.status.status == "ready":
                self.solution_nodes.append(node)
                removal.append(node)
        for node in removal:
            self.leaves.remove(node)

    def __call__(self, query, T=0):
        self.root = Node(thought = query, status = Status(status="continue"))
        node = self.root
        self.leaves = deque([self.root])
        for depth in range(self.max_depth):
            while self.leaves and node.depth < depth:
                if DEBUG:
                    print(f"Current Depth: {node.depth} out of {self.max_depth}")
                node = self.leaves.popleft()
                for i in range(self.max_children):
                    thought = self.thought_model(
                        Thought_user_prompt.format(query=query, chain_of_thoughts='\n-'.join(node.thoughts)), 
                        temperature = T)
                    node.children.append(Node(thought = thought.thought, parent = node))
                self.leaves.extend(node.children)
            self.set_status(query)
            self.prune()
        self.solutions = [self.solution_model(
            Solution_user_prompt.format(query=query, chain_of_thoughts='\n-'.join(node.thoughts))
            ).solution for node in self.solution_nodes]
        if not self.solutions:
            self.solutions = ["No solution found"]
        return self.solutions

class BinTree(Tree):
    def __init__(self, status_model, thought_model, solution_model, max_depth: int = 5, low_temp = 0, high_temp = 1):
        super().__init__(status_model=status_model, thought_model=thought_model, solution_model=solution_model, max_depth=max_depth)
        self.root = None
        self.leaves = deque([self.root]) # list of leaves
        self.low_temp = low_temp
        self.high_temp = high_temp

    def __call__(self, query):
        self.root = BinNode(thought = query, status = Status(status="continue"))
        node = self.root
        self.leaves = deque([self.root])
        for depth in range(self.max_depth):
            while self.leaves and node.depth < self.max_depth:
                if DEBUG:
                    print(f"Current Depth: {node.depth} out of {self.max_depth}")
                node = self.leaves.popleft()
                thought = self.thought_model(
                    Thought_user_prompt.format(query=query, chain_of_thoughts='\n-'.join(node.thoughts)), 
                    temperature = self.low_temp)
                node.low_temp = BinNode(thought = thought.thought, parent = node)
                thought = self.thought_model(
                    Thought_user_prompt.format(query=query, chain_of_thoughts='\n-'.join(node.thoughts)), 
                    temperature = self.high_temp)
                node.high_temp = BinNode(thought = thought.thought, parent = node)
                self.leaves.extend([node.low_temp, node.high_temp])
            self.set_status(query)
            self.prune()
        self.solutions = [self.solution_model(
            Solution_user_prompt.format(query=query, chain_of_thoughts='\n-'.join(node.thoughts))
            ).solution for node in self.solution_nodes]
        if not self.solutions:
            self.solutions = ["No solution found"]
        return self.solutions

In [102]:
tree = BinTree(Status_model, Thought_model, Solution_model)
tree("What is the solution to the equation x^2 - 4 = 0?", T=1)

Current Depth: 0 out of 5
Current Depth: 0 out of 5
Current Depth: 1 out of 5
Current Depth: 1 out of 5
Current Depth: 2 out of 5
Current Depth: 2 out of 5
Current Depth: 2 out of 5
Current Depth: 2 out of 5
Current Depth: 3 out of 5
Current Depth: 3 out of 5
Current Depth: 3 out of 5
Current Depth: 3 out of 5
Current Depth: 3 out of 5
Current Depth: 3 out of 5
Current Depth: 3 out of 5
Current Depth: 3 out of 5


['x = 2 or x = -2', 'x = 2 or x = -2']

In [112]:
tree.leaves[0].depth

4

In [93]:
len(tree.leaves)

17

In [32]:
from openai import OpenAI

client = OpenAI()

completion = client.beta.chat.completions.parse(
    model="gpt-4o-mini",
    temperature=0,
    messages=[
        {"role": "system", "content": "Given the user query and the set of thoughts, generate the next step (a thought) to solve the problem but do not generate a solution. Also remember that you do not have access to any outside tools or sources of knowledge."},
        {"role": "user", "content": "Query: What is the largest city in the united states? \n Thoughts: \n"},
    ],
    response_format=Thought,
    logprobs=True,
    top_logprobs=3,
)

thought = completion.choices[0].message.parsed

In [33]:
completion.choices[0].logprobs.content

[ChatCompletionTokenLogprob(token='{"', bytes=[123, 34], logprob=0.0, top_logprobs=[TopLogprob(token='{"', bytes=[123, 34], logprob=0.0), TopLogprob(token='{', bytes=[123], logprob=-18.25), TopLogprob(token=' {"', bytes=[32, 123, 34], logprob=-22.125)]),
 ChatCompletionTokenLogprob(token='thought', bytes=[116, 104, 111, 117, 103, 104, 116], logprob=0.0, top_logprobs=[TopLogprob(token='thought', bytes=[116, 104, 111, 117, 103, 104, 116], logprob=0.0), TopLogprob(token='th', bytes=[116, 104], logprob=-20.0), TopLogprob(token='though', bytes=[116, 104, 111, 117, 103, 104], logprob=-24.0)]),
 ChatCompletionTokenLogprob(token='":"', bytes=[34, 58, 34], logprob=-1.9361265e-07, top_logprobs=[TopLogprob(token='":"', bytes=[34, 58, 34], logprob=-1.9361265e-07), TopLogprob(token='":"\'', bytes=[34, 58, 34, 39], logprob=-16.5), TopLogprob(token='":', bytes=[34, 58], logprob=-18.25)]),
 ChatCompletionTokenLogprob(token='I', bytes=[73], logprob=-0.2783952, top_logprobs=[TopLogprob(token='I', bytes=

In [44]:
[1,2,3][-1::-1]

[3, 2, 1]