Paper: Planning with Large Language Models for Code Generation, ICLR 2023

Large Language Models (LLMs) often generate SQL queries that contain subtle but critical errors. To solve this, I implement a tree-search algorithm that explores different ways to construct a query. The algorithm builds a query token by token, treating each addition as a step in a plan. At each step, it uses the LLM to suggest promising next tokens. It then evaluates the quality of a potential query by attempting to execute it against a database, which provides a direct "reward" signal. By repeating this process, the algorithm learns to navigate away from erroneous paths and focuses its search on sequences that lead to valid, executable SQL.

The original paper applies its Planning-Guided Transformer Decoding (PG-TD) algorithm to generate Python code for competitive programming challenges. It uses the pass rate on unit tests as its reward signal to guide the search. I adapt this core framework directly. My implementation replaces the Python code generation task with SQL query generation. Instead of using unit tests, I design a new reward function that checks for SQL syntax and execution errors against a live database, demonstrating the versatility of the paper's planning-based approach.

In [8]:
import sqlite3
import math
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

db_path = "academic.sqlite"

In [9]:
def get_schema(db_path):
    """
    Connects to the SQLite database and extracts the schema of all tables.
    Returns a formatted string describing the tables and their columns.
    """
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        # Gets list of tables
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()

        schema_string = ""
        for table in tables:
            table_name = table[0]
            schema_string += f"Table {table_name}: ("

            # Gets column info for each table
            cursor.execute(f"PRAGMA table_info({table_name});")
            columns = cursor.fetchall()

            col_names = [col[1] for col in columns]
            schema_string += ", ".join(col_names) + ")\n"

        conn.close()
        return schema_string
    except sqlite3.Error as e:
        print(f"Database error: {e}")
        return None


def generate_baseline_sql(question, db_path):
    """Generates SQL using a simple, one-shot prompt."""
    schema = get_schema(db_path)
    if not schema:
        return "Error: Could not get database schema."

    prompt = f"""### Instructions:
Given the following database schema, write a SQL query that answers the question.

### Schema:
{schema}
### Question:
{question}

### SQL Query:
SELECT"""

    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id
    )

    # Cleaning the output
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    sql_query = "SELECT" + generated_text.split("### SQL Query:\nSELECT")[-1].split('\n')[0]

    return sql_query.strip()

In [10]:
def get_reward(sql_query, db_path):
    """
    Executes a SQL query and returns a reward score.
    1.0 for success, 0.5 for execution error, 0.0 for an empty query.
    """
    if not sql_query or sql_query.isspace():
        return 0.0

    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        cursor.execute(sql_query)
        conn.close()
        return 1.0
    except sqlite3.Error:
        return 0.5

In [11]:
class TreeNode:
    """A node in the Monte Carlo Search Tree."""
    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = []

        self.visits = 0
        self.total_reward = 0.0

    def is_fully_expanded(self):
        # A node is fully expanded if it has any children.
        return len(self.children) > 0

    def average_reward(self):
        """Calculates the average reward of the node."""
        if self.visits == 0:
            return 0
        return self.total_reward / self.visits

In [12]:
def select(node):
    """
    Selects the best child node
    This balances exploration and exploitation.
    """
    best_child = None
    best_ucb_score = -1

    # UCB1 formula: avg_reward + C * sqrt(log(parent_visits) / child_visits)
    exploration_constant = 1.5

    for child in node.children:
        if child.visits == 0:
            # If a child has not been visited, it's a top priority
            return child

        exploitation_term = child.average_reward()
        exploration_term = exploration_constant * math.sqrt(
            math.log(node.visits) / child.visits
        )
        ucb_score = exploitation_term + exploration_term

        if ucb_score > best_ucb_score:
            best_ucb_score = ucb_score
            best_child = child

    return best_child

def expand(node, model, tokenizer, top_k=5):
    """
    Generates potential next tokens (actions) from the current state
    and creates new child nodes for them.
    """
    if node.state.strip().endswith(";"):
        return

    prompt = node.state
    inputs = tokenizer(prompt, return_tensors='pt')

    with torch.no_grad():
        outputs = model(**inputs)

    # Probabilities for the next token
    next_token_logits = outputs.logits[:, -1, :]
    top_k_tokens = torch.topk(next_token_logits, top_k, dim=-1).indices.squeeze().tolist()

    for token_id in top_k_tokens:
        new_state = node.state + tokenizer.decode(token_id)
        child = TreeNode(state=new_state, parent=node)
        node.children.append(child)

def evaluate(node, model, tokenizer, db_path, max_new_tokens_to_generate=30):
    """
    This is the "rollout". It completes the query from the current node's
    partial state and gets a reward.
    """
    prompt = node.state
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens_to_generate,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id
    )

    completed_query = tokenizer.decode(outputs[0], skip_special_tokens=True)

    if "###" in completed_query:
        # Considering the part after the initial prompt
        completed_query = completed_query[len(prompt):]
        if "###" in completed_query:
            completed_query = completed_query.split("###")[0]
        completed_query = prompt + completed_query

    return get_reward(completed_query.strip(), db_path)

def backpropagate(node, reward):
    """
    Updates the visit counts and total rewards of a node and its ancestors.
    """
    current = node
    while current is not None:
        current.visits += 1
        current.total_reward += reward
        current = current.parent

In [13]:
def pg_td_search(question, db_path, iterations=20):
    """
    Performs the Planning-Guided Tree Search to generate a SQL query.
    """
    schema = get_schema(db_path)
    initial_prompt = f"""### Schema:
{schema}
### Question:
{question}

### SQL Query:
"""
    root = TreeNode(state=initial_prompt)

    for i in range(iterations):
        node = root

        # 1. Selection Phase: Traverse down the tree
        while node.is_fully_expanded() and node.children:
            node = select(node)

        # 2. Expansion Phase: If the node is not fully expanded, expand it
        if not node.is_fully_expanded():
            expand(node, model, tokenizer)

        # 3. Evaluation Phase: Choose a child to evaluate
        if node.children:
            child_to_evaluate = node.children[0] # Simple choice: evaluate the first new child
            reward = evaluate(child_to_evaluate, model, tokenizer, db_path)

            # 4. Backpropagation Phase
            backpropagate(child_to_evaluate, reward)

    # After all iterations, find the best query by choosing the most visited path
    best_query_node = root
    while best_query_node.children:
        # Choose child with the highest number of visits
        best_query_node = max(best_query_node.children, key=lambda n: n.visits)

    return best_query_node.state.split("SQL Query:\n")[-1].strip()

In [14]:
import pandas as pd

test_questions = [
    "How many students are there?",
    "What are the names of all departments?",
    "Find the number of courses taught by the professor with the last name 'Smith'."
]

results = []
for q in test_questions:
    print(f"Processing question: '{q}'")

    # Baseline Result
    baseline_sql = generate_baseline_sql(q, db_path)
    baseline_reward = get_reward(baseline_sql, db_path)

    # PG-TD Result
    pgtd_sql = pg_td_search(q, db_path, iterations=30)
    pgtd_reward = get_reward(pgtd_sql, db_path)

    results.append({
        "Question": q,
        "Baseline SQL": baseline_sql,
        "Baseline Reward": baseline_reward,
        "PG-TD SQL": pgtd_sql,
        "PG-TD Reward": pgtd_reward
    })

print("\nComparison Results:")
df = pd.DataFrame(results)
display(df)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Processing question: 'How many students are there?'


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Processing question: 'What are the names of all departments?'


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Processing question: 'Find the number of courses taught by the professor with the last name 'Smith'.'

Comparison Results:


Unnamed: 0,Question,Baseline SQL,Baseline Reward,PG-TD SQL,PG-TD Reward
0,How many students are there?,"SELECT * FROM (id, name, oid) WHERE id = 'id' ...",0.5,,0.0
1,What are the names of all departments?,"SELECT * FROM (id, name, oid) WHERE (id = 'id'...",0.5,,0.0
2,Find the number of courses taught by the profe...,"SELECT * FROM (id, name, oid) WHERE (id = 'Smi...",0.5,,0.0
