<a href="https://colab.research.google.com/github/satojkovic/ToT-Colab/blob/main/tree_of_thoughts_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Tree of Thoughts (ToT) Demo

This notebook demonstrates the Tree of Thoughts (ToT) algorithm using the `tree-of-thoughts-llm` package.  
Paper: [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/abs/2305.10601)

## Algorithm Overview

Conventional prompting produces an answer in a single pass. Tree of Thoughts instead builds a **tree of thoughts**, in which each node is a partial solution, and searches that tree:

1. **Thought Generation**: Generate several candidate thoughts from each current state
2. **State Evaluation**: Score how promising each resulting state is
3. **Selection**: Keep the most promising states for the next step
4. **Search**: Repeat these steps level by level; this notebook uses breadth-first search (BFS)
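
The four steps above can be sketched without any LLM calls. The following is a minimal illustration, not the package's implementation: `tot_bfs`, `propose`, `evaluate`, and the toy target string are all hypothetical names chosen for this example, with a string-matching score standing in for the LLM-based evaluator.

```python
from typing import Callable, List, Tuple


def tot_bfs(
    root: str,
    propose: Callable[[str], List[str]],
    evaluate: Callable[[str], float],
    steps: int,
    beam_width: int,
) -> List[Tuple[str, float]]:
    """Breadth-first search over thoughts, keeping `beam_width` states per step."""
    frontier = [root]
    scored: List[Tuple[str, float]] = [(root, evaluate(root))]
    for _ in range(steps):
        # 1. Thought generation: expand every state in the frontier.
        candidates = [child for state in frontier for child in propose(state)]
        if not candidates:
            break
        # 2. State evaluation: score each candidate.
        scored = [(c, evaluate(c)) for c in candidates]
        # 3. Selection: keep the highest-scoring candidates (greedy).
        scored.sort(key=lambda pair: pair[1], reverse=True)
        scored = scored[:beam_width]
        # 4. Search: the survivors become the next level of the tree.
        frontier = [c for c, _ in scored]
    return scored


# Toy problem (hypothetical): build the string "cab" one letter at a time.
TARGET = "cab"


def propose(state: str) -> List[str]:
    """Stand-in for the LLM proposer: append each possible letter."""
    return [state + ch for ch in "abc"]


def evaluate(state: str) -> float:
    """Stand-in for the LLM evaluator: count positions that match the target."""
    return sum(1 for a, b in zip(state, TARGET) if a == b)


result = tot_bfs("", propose, evaluate, steps=3, beam_width=2)
print(result)  # [('cab', 3), ('caa', 2)]
```

In the real algorithm, `propose` and `evaluate` would each be backed by LLM prompts, which is where most of the cost of the search comes from.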

## 1. Environment Setup and Library Installation

In [None]:
# Install required libraries
!pip install openai sympy pandas numpy matplotlib seaborn
!pip install tree-of-thoughts-llm

In [None]:
# Clone official code (alternative)
# !git clone https://github.com/princeton-nlp/tree-of-thought-llm.git
# %cd tree-of-thought-llm
# !pip install -e .

In [None]:
import os
import json
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Any

# OpenAI API setup
import openai
from google.colab import userdata

# Set OpenAI API key (using Google Colab Secrets)
os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')

## 2. Basic Usage Example of Tree of Thoughts

In [None]:
from tot.methods.bfs import solve
from tot.tasks.game24 import Game24Task

# Parameter configuration
args = argparse.Namespace(
    backend='gpt-4',
    temperature=0.7,
    task='game24',
    naive_run=False,
    prompt_sample=None,
    method_generate='propose',
    method_evaluate='value',
    method_select='greedy',
    n_generate_sample=1,
    n_evaluate_sample=3,
    n_select_sample=5
)

print("Tree of Thoughts configuration:")
print(f"- Model: {args.backend}")
print(f"- Thought generation: {args.method_generate}")
print(f"- State evaluation: {args.method_evaluate}")
print(f"- Selection method: {args.method_select}")
print(f"- Candidates kept per step: {args.n_select_sample}")

## 3. Game24 Task Demonstration

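Game24 is a mathematical puzzle: combine four given numbers with the basic arithmetic operations (`+`, `-`, `*`, `/`), using each number exactly once, to obtain 24. For example, the numbers 4, 9, 10, 13 can be solved with `(13 - 9) * (10 - 4) = 24`.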
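
To make the rules concrete, here is a small standalone checker. It is a sketch written for this explanation and is independent of the `tot` package: `is_valid_game24` and `_eval` are hypothetical helpers, and exact `Fraction` arithmetic is a choice made here so that division-based solutions are not rejected by floating-point rounding.

```python
import ast
import operator
import re
from fractions import Fraction

# Only the four basic binary operations are allowed.
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def _eval(node: ast.AST) -> Fraction:
    """Evaluate an arithmetic AST exactly, allowing only + - * / and integers."""
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        return _OPS[type(node.op)](_eval(node.left), _eval(node.right))
    if isinstance(node, ast.Constant) and type(node.value) is int:
        return Fraction(node.value)
    raise ValueError("unsupported expression")


def is_valid_game24(numbers: str, expression: str) -> bool:
    """Return True if `expression` uses exactly the given numbers and equals 24."""
    used = sorted(re.findall(r"\d+", expression))
    given = sorted(numbers.split())
    if used != given:
        return False
    try:
        return _eval(ast.parse(expression, mode="eval").body) == 24
    except (ValueError, ZeroDivisionError, SyntaxError):
        return False


print(is_valid_game24("4 9 10 13", "(13 - 9) * (10 - 4)"))  # True
print(is_valid_game24("3 3 8 8", "8 / (3 - 8 / 3)"))        # True
print(is_valid_game24("1 2 3 4", "4 * 6"))                  # False
```

The last call fails because the expression does not use the four given numbers, even though it evaluates to 24.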

In [None]:
# Initialize Game24 task
task = Game24Task()

# Display task details
print(f"Dataset size: {len(task)}")
print(f"Example problem: {task.get_input(0)}")
print(f"Number of search steps: {task.steps}")

# Display first 5 problems
print("\nExample problems:")
for i in range(5):
    print(f"Problem {i+1}: {task.get_input(i)}")

In [None]:
# Solve single problem
problem_idx = 0  # Index of the problem to solve
input_numbers = task.get_input(problem_idx)

print(f"Problem: {input_numbers}")
print("Solving...\n")

# Solve with ToT algorithm
solutions, info = solve(args, task, problem_idx)

print("\n=== Solution Results ===")
for i, solution in enumerate(solutions):
    print(f"Solution {i+1}:")
    print(solution)
    print(f"Correct: {task.test_output(problem_idx, solution)['r'] == 1}")
    print()

## 4. Algorithm Process Visualization

In [None]:
def visualize_search_process(info: Dict[str, Any]):
    """
    Visualize the search process
    """
    steps = info['steps']
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Tree of Thoughts Search Process', fontsize=16)
    
    # 1. Number of candidates at each step
    step_nums = []
    candidate_counts = []
    selected_counts = []
    
    for step_info in steps:
        step_nums.append(step_info['step'])
        candidate_counts.append(len(step_info['new_ys']))
        selected_counts.append(len(step_info['select_new_ys']))
    
    axes[0, 0].bar(step_nums, candidate_counts, alpha=0.7, label='Generated candidates')
    axes[0, 0].bar(step_nums, selected_counts, alpha=0.7, label='Selected candidates')
    axes[0, 0].set_xlabel('Step')
    axes[0, 0].set_ylabel('Number of candidates')
    axes[0, 0].set_title('Number of candidates at each step')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # 2. Distribution of evaluation values
    all_values = []
    for step_info in steps:
        all_values.extend(step_info['values'])
    
    axes[0, 1].hist(all_values, bins=20, alpha=0.7, edgecolor='black')
    axes[0, 1].set_xlabel('Evaluation value')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].set_title('Distribution of evaluation values')
    axes[0, 1].grid(True, alpha=0.3)
    
    # 3. Highest evaluation value at each step
    max_values = []
    avg_values = []
    
    for step_info in steps:
        values = step_info['values']
        max_values.append(max(values) if values else 0)
        avg_values.append(np.mean(values) if values else 0)
    
    axes[1, 0].plot(step_nums, max_values, marker='o', label='Maximum value')
    axes[1, 0].plot(step_nums, avg_values, marker='s', label='Average value')
    axes[1, 0].set_xlabel('Step')
    axes[1, 0].set_ylabel('Evaluation value')
    axes[1, 0].set_title('Evaluation value progression by step')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # 4. Evaluation values of selected candidates
    # (assumes method_select='greedy', where the selected candidates are the top-valued ones)
    for step_info in steps:
        values = step_info['values']
        select_count = len(step_info['select_new_ys'])
        if values:
            top_values = sorted(values, reverse=True)[:select_count]
            # Plot against the real step number so the x-axis matches the other panels
            axes[1, 1].scatter([step_info['step']] * len(top_values), top_values, alpha=0.7)
    
    axes[1, 1].set_xlabel('Step')
    axes[1, 1].set_ylabel('Evaluation value')
    axes[1, 1].set_title('Evaluation values of selected candidates')
    axes[1, 1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Visualize search process
if 'info' in locals():
    visualize_search_process(info)

## 5. Performance Evaluation on Multiple Problems

In [None]:
def evaluate_multiple_problems(task, args, start_idx=0, end_idx=10):
    """
    Evaluate ToT performance on multiple problems
    """
    results = []
    
    stop_idx = min(end_idx, len(task))
    for i in range(start_idx, stop_idx):
        print(f"Problem {i - start_idx + 1}/{stop_idx - start_idx}: {task.get_input(i)}")
        
        try:
            solutions, info = solve(args, task, i, to_print=False)
            
            # Evaluate the first returned solution (the highest-valued one under greedy selection)
            best_solution = solutions[0] if solutions else ""
            test_result = task.test_output(i, best_solution)
            
            results.append({
                'problem_idx': i,
                'input': task.get_input(i),
                'solution': best_solution,
                'correct': test_result['r'],
                'steps': len(info['steps']) if 'steps' in info else 0
            })
            
            print(f"Result: {'Correct' if test_result['r'] else 'Incorrect'}")
            print()
            
        except Exception as e:
            print(f"Error: {e}")
            results.append({
                'problem_idx': i,
                'input': task.get_input(i),
                'solution': '',
                'correct': 0,
                'steps': 0
            })
    
    return results

# Run multiple problem evaluation
print("Starting performance evaluation on multiple problems...")
evaluation_results = evaluate_multiple_problems(task, args, start_idx=0, end_idx=5)

# Analyze results
correct_count = sum(1 for r in evaluation_results if r['correct'])
total_count = len(evaluation_results)
accuracy = correct_count / total_count if total_count > 0 else 0

print(f"\n=== Evaluation Results ===")
print(f"Correct answers: {correct_count}/{total_count}")
print(f"Accuracy: {accuracy:.2%}")

# Display results in DataFrame
df_results = pd.DataFrame(evaluation_results)
print("\nDetailed results:")
print(df_results[['problem_idx', 'input', 'correct']].to_string(index=False))

## 6. Comparison of Different Algorithm Configurations

In [None]:
def compare_configurations():
    """
    Compare different ToT configurations
    """
    configurations = [
        {
            'name': 'ToT (greedy)',
            'method_generate': 'propose',
            'method_evaluate': 'value',
            'method_select': 'greedy',
            'n_select_sample': 3
        },
        {
            'name': 'ToT (sample)',
            'method_generate': 'propose',
            'method_evaluate': 'value',
            'method_select': 'sample',
            'n_select_sample': 3
        },
        {
            'name': 'ToT (wide search)',
            'method_generate': 'propose',
            'method_evaluate': 'value',
            'method_select': 'greedy',
            'n_select_sample': 5
        }
    ]
    
    comparison_results = []
    
    for config in configurations:
        print(f"\nConfiguration: {config['name']}")
        
        # Update configuration
        test_args = argparse.Namespace(**vars(args))
        for key, value in config.items():
            if key != 'name':
                setattr(test_args, key, value)
        
        # Test with a small sample (3 problems, so the accuracies are only indicative)
        results = evaluate_multiple_problems(task, test_args, start_idx=0, end_idx=3)
        
        correct_count = sum(1 for r in results if r['correct'])
        # Guard against an empty result list to avoid division by zero
        accuracy = correct_count / len(results) if results else 0.0
        comparison_results.append({
            'configuration': config['name'],
            'accuracy': accuracy,
            'correct_count': correct_count,
            'total_count': len(results)
        })
        
        print(f"Accuracy: {accuracy:.2%}")
    
    # Visualize results
    df_comparison = pd.DataFrame(comparison_results)
    
    plt.figure(figsize=(10, 6))
    plt.bar(df_comparison['configuration'], df_comparison['accuracy'])
    plt.title('Performance Comparison of Different ToT Configurations')
    plt.xlabel('Configuration')
    plt.ylabel('Accuracy')
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    
    return comparison_results

# Run configuration comparison
print("Starting comparison of different algorithm configurations...")
comparison_results = compare_configurations()

print("\n=== Comparison Results ===")
for result in comparison_results:
    print(f"{result['configuration']}: {result['accuracy']:.2%} ({result['correct_count']}/{result['total_count']})")

## 7. Creating and Solving Custom Problems

In [None]:
def solve_custom_problem(numbers_str: str):
    """
    Solve custom problem
    """
    print(f"Custom problem: {numbers_str}")
    
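    # Create temporary task class
    class CustomGame24Task(Game24Task):
        def __init__(self, custom_input):
            super().__init__()
            self.custom_input = custom_input
            # Also replace the dataset so that checks which read self.data[idx]
            # (as test_output is assumed to do) compare against the custom numbers
            self.data = [custom_input]
        
        def get_input(self, idx):
            return self.custom_input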
    
    custom_task = CustomGame24Task(numbers_str)
    
    try:
        solutions, info = solve(args, custom_task, 0)
        
        print("\n=== Solution Results ===")
        for i, solution in enumerate(solutions):
            print(f"Solution {i+1}:")
            print(solution)
            result = custom_task.test_output(0, solution)
            print(f"Correct: {result['r'] == 1}")
            print()
        
        return solutions, info
        
    except Exception as e:
        print(f"Error: {e}")
        return [], {}

# Custom problem examples
custom_problems = [
    "1 2 3 4",
    "4 1 8 7",
    "2 3 5 6"
]

print("Solving custom problems:")
for problem in custom_problems:
    print("\n" + "="*50)
    solve_custom_problem(problem)

## 8. Summary and Future Directions

### Features of Tree of Thoughts

1. **Structured Search**: Builds and searches a tree of intermediate thoughts instead of producing an answer in a single pass
2. **Intermediate State Evaluation**: Scores each partial solution and prunes unpromising branches early
3. **Flexible Configuration**: Generation, evaluation, and selection strategies can be swapped independently (for example, `greedy` vs. `sample` selection)
4. **High Problem-Solving Capability**: In the paper, ToT with GPT-4 solved 74% of Game of 24 problems, compared with 4% for chain-of-thought prompting
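
The two selection strategies compared in this notebook differ only in how candidate indices are chosen from their evaluation values. The sketch below illustrates the idea with the standard library; `select_greedy` and `select_sample` are hypothetical helpers written for this explanation, not functions from the `tot` package, and sampling with replacement is an assumption made here.

```python
import random
from typing import List, Sequence


def select_greedy(values: Sequence[float], k: int) -> List[int]:
    """Return the indices of the k highest-valued candidates."""
    ids = range(len(values))
    return sorted(ids, key=lambda i: values[i], reverse=True)[:k]


def select_sample(values: Sequence[float], k: int, rng: random.Random) -> List[int]:
    """Draw k indices with probability proportional to value (with replacement).

    Raises ValueError if all values are zero, since no distribution is defined.
    """
    ids = list(range(len(values)))
    return rng.choices(ids, weights=values, k=k)


values = [0.1, 20.0, 1.0, 3.0]
print(select_greedy(values, k=2))                      # [1, 3]
print(select_sample(values, k=2, rng=random.Random(0)))
```

Greedy selection is deterministic and always keeps the top-valued states, while sampling occasionally keeps lower-valued states, trading some exploitation for exploration when the evaluator is noisy.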

### Applicable Domains

- **Mathematical Problem Solving**: Numerical calculation problems like Game24
- **Creative Writing**: Novel and poetry composition
- **Logic Puzzles**: Crossword puzzles and similar challenges
- **Strategic Thinking**: Game theory and decision-making problems

### Future Improvements

1. **Efficiency**: Reducing computational costs and improving response times
2. **Better Evaluation Functions**: More accurate intermediate state evaluation
3. **Dynamic Search**: Adaptive search depth based on problem complexity
4. **Multimodal Support**: Handling non-text inputs