## Guided Token Generation Code

This notebook runs in Google Colab. Before running it, make sure the ANTLR-generated SimpleSQL grammar files (e.g. `SimpleSQLLexer.py` and `SimpleSQLParser.py`) and `antlr.jar` are in the working directory. If they are missing, generate them from `SimpleSQL.g4` by uncommenting the code cells in step 3.

1. Download the Spider dataset from Kaggle

In [1]:
# Ref: https://www.kaggle.com/datasets/jeromeblanchet/yale-universitys-spider-10-nlp-dataset/data
import kagglehub

# Kaggle handle of the Spider 1.0 text-to-SQL dataset. The returned value is
# the local directory the files were downloaded to.
SPIDER_KAGGLE_HANDLE = "jeromeblanchet/yale-universitys-spider-10-nlp-dataset"
dataset_path = kagglehub.dataset_download(SPIDER_KAGGLE_HANDLE)

2. Preview the training data

In [2]:
import pandas as pd
import json
import os

spider_path = os.path.join(dataset_path, "spider")
data_path = os.path.join(spider_path, "train_spider.json")

# train_spider.json holds one record per (question, SQL query) training example.
with open(data_path, "r") as train_file:
  data = json.load(train_file)
df = pd.DataFrame(data)

# Show the first ten training examples.
df.head(10)

Unnamed: 0,db_id,query,query_toks,query_toks_no_value,question,question_toks,sql
0,department_management,SELECT count(*) FROM head WHERE age > 56,"[SELECT, count, (, *, ), FROM, head, WHERE, ag...","[select, count, (, *, ), from, head, where, ag...",How many heads of the departments are older th...,"[How, many, heads, of, the, departments, are, ...","{'except': None, 'from': {'conds': [], 'table_..."
1,department_management,"SELECT name , born_state , age FROM head ORD...","[SELECT, name, ,, born_state, ,, age, FROM, he...","[select, name, ,, born_state, ,, age, from, he...","List the name, born state and age of the heads...","[List, the, name, ,, born, state, and, age, of...","{'except': None, 'from': {'conds': [], 'table_..."
2,department_management,"SELECT creation , name , budget_in_billions ...","[SELECT, creation, ,, name, ,, budget_in_billi...","[select, creation, ,, name, ,, budget_in_billi...","List the creation year, name and budget of eac...","[List, the, creation, year, ,, name, and, budg...","{'except': None, 'from': {'conds': [], 'table_..."
3,department_management,"SELECT max(budget_in_billions) , min(budget_i...","[SELECT, max, (, budget_in_billions, ), ,, min...","[select, max, (, budget_in_billions, ), ,, min...",What are the maximum and minimum budget of the...,"[What, are, the, maximum, and, minimum, budget...","{'except': None, 'from': {'conds': [], 'table_..."
4,department_management,SELECT avg(num_employees) FROM department WHER...,"[SELECT, avg, (, num_employees, ), FROM, depar...","[select, avg, (, num_employees, ), from, depar...",What is the average number of employees of the...,"[What, is, the, average, number, of, employees...","{'except': None, 'from': {'conds': [], 'table_..."
5,department_management,SELECT name FROM head WHERE born_state != 'Cal...,"[SELECT, name, FROM, head, WHERE, born_state, ...","[select, name, from, head, where, born_state, ...",What are the names of the heads who are born o...,"[What, are, the, names, of, the, heads, who, a...","{'except': None, 'from': {'conds': [], 'table_..."
6,department_management,SELECT DISTINCT T1.creation FROM department AS...,"[SELECT, DISTINCT, T1.creation, FROM, departme...","[select, distinct, t1, ., creation, from, depa...",What are the distinct creation years of the de...,"[What, are, the, distinct, creation, years, of...","{'except': None, 'from': {'conds': [[False, 2,..."
7,department_management,SELECT born_state FROM head GROUP BY born_stat...,"[SELECT, born_state, FROM, head, GROUP, BY, bo...","[select, born_state, from, head, group, by, bo...",What are the names of the states where at leas...,"[What, are, the, names, of, the, states, where...","{'except': None, 'from': {'conds': [], 'table_..."
8,department_management,SELECT creation FROM department GROUP BY creat...,"[SELECT, creation, FROM, department, GROUP, BY...","[select, creation, from, department, group, by...",In which year were most departments established?,"[In, which, year, were, most, departments, est...","{'except': None, 'from': {'conds': [], 'table_..."
9,department_management,"SELECT T1.name , T1.num_employees FROM depart...","[SELECT, T1.name, ,, T1.num_employees, FROM, d...","[select, t1, ., name, ,, t1, ., num_employees,...",Show the name and number of employees for the ...,"[Show, the, name, and, number, of, employees, ...","{'except': None, 'from': {'conds': [[False, 2,..."


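3. The next few code cells install the ANTLR Python runtime, download `antlr.jar`, and use it to generate the SimpleSQL lexer/parser files from `SimpleSQL.g4` (uncomment them if you do not already have these files).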

In [3]:
# Ref: https://github.com/jszheng/py3antlr4book
# Ref: https://www.antlr.org/download.html
# !pip install antlr4-python3-runtime

# !wget https://www.antlr.org/download/antlr-4.9.3-complete.jar -O antlr.jar

# ANTLR_JAR = "antlr.jar"

In [4]:
# !rm -f SimpleSQLParser.py SimpleSQLParser.tokens SimpleSQLParser.interp
# !rm -f SimpleSQLListener.py SimpleSQLVisitor.py
# !rm -f SimpleSQL.interp SimpleSQL.tokens

In [5]:
# Ref: https://dzone.com/articles/building-sql-to-dataframe-converter-with-antlr

In [6]:
# !java -jar antlr.jar -Dlanguage=Python3 SimpleSQL.g4

4. The following error listener collects the syntax errors ANTLR reports while parsing a SQL prefix with the SimpleSQL grammar. The "expecting ..." part of those messages tells us which tokens may come next.

In [7]:
# Ref: https://dzone.com/articles/how-to-perform-custom-error-handling-with-antlr
from antlr4.error.ErrorListener import ErrorListener

class GuidedTokenErrorListener(ErrorListener):
  """ANTLR error listener that collects syntax-error messages.

  Only the message text is kept. The recognizer, offending symbol, position
  and exception passed by the parser are ignored.
  """

  def __init__(self):
    # Messages in the order the parser reported them.
    self.syntax_errors = []

  def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
    self.syntax_errors += [msg]

5. Load the Spider table schemas (`tables.json`)

In [8]:
import os
import re
import json

# tables.json describes every Spider database, including its table names and
# its (table index, column name) pairs.
tables_data = os.path.join(dataset_path, "spider", "tables.json")

with open(tables_data, "r") as tables_file:
  spider_tables = json.load(tables_file)

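6. Guided Token Generator class. This class builds a SQL query for a given question and database one token at a time: the grammar and the schema restrict the valid next tokens, and the LLM chooses among them.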

In [9]:
from SimpleSQLLexer import SimpleSQLLexer
from SimpleSQLParser import SimpleSQLParser
from antlr4 import InputStream, CommonTokenStream

class GuidedTokenGenerator:
  """Grammar-guided, LLM-driven SQL generator with backtracking.

  At each step the generator builds the set of tokens that may follow the
  current SQL prefix. That set comes from the expected-token lists in ANTLR
  syntax errors for the SimpleSQL grammar, plus a few heuristics. The
  generator then removes PII tokens and choices already proven to be dead
  ends, and asks the LLM to pick one. An invalid pick or a dead end pops the
  last token and records it as exhausted at its position.

  Args:
    pii_tokens: names that must never be generated; stored lower-cased.
    llm: callable taking a prompt string and returning the model's text.
    get_schema: callable mapping a db id to a dict with keys "tables"
      (list of table names) and "table_columns" (table name -> column list),
      or returning None for an unknown db.
  """

  def __init__(self, pii_tokens, llm, get_schema):
    self.pii_tokens = {t.lower() for t in pii_tokens}
    self.llm = llm
    self.get_schema = get_schema

    self.user_prompt = None
    self.db = None
    self.generated_tokens = []  # confirmed SQL prefix, one entry per token
    self.candidate_tokens = []  # safe next tokens offered at the current step
    self.is_finished = False

    self.schema_text = None
    self.schema_tables = None
    self.schema_cols = set()

    # prefix position -> tokens already proven to be dead ends at that position
    self.exhausted_paths = {}

  def start(self, prompt, db=None):
    """Reset all generation state for a new question and load `db`'s schema.

    Returns:
      `db`, unchanged.
    Raises:
      ValueError: if `get_schema` returns None for `db`.
    """
    schema_dict = self.get_schema(db)
    if schema_dict is None:
      raise ValueError(f"Unknown database: {db!r}")

    self.user_prompt = prompt
    self.db = db

    # Store the schema both as prompt text and as token sets.
    self.schema_tables = schema_dict["tables"]
    table_lines = "\n".join(
      f"  - {table}: {', '.join(cols)}"
      for table, cols in schema_dict["table_columns"].items()
    )
    self.schema_text = (
      "Database schema:\n"
      f"Tables: {', '.join(schema_dict['tables'])}\n"
      "Table Columns:\n" + table_lines
    )
    # Lower-case column names so they compare against the lower-cased PII set.
    self.schema_cols = {
      col.lower() for cols in schema_dict["table_columns"].values() for col in cols
    }

    self.generated_tokens = []
    self.candidate_tokens = []
    self.is_finished = False
    self.exhausted_paths = {}

    return self.db

  def extract_from_errors(self, errors):
    """Return the tokens ANTLR expected, taken from the first error message.

    Understands "... expecting {'A', 'B'}" / "... expecting X" messages and
    ANTLR's "missing X at '...'" report, which has no "expecting" clause.
    Quotes are stripped from literal tokens. Symbolic names such as IDENT
    are returned as-is.
    """
    if not errors:
      return set()

    first = errors[0]
    errors_match = re.search(r"expecting (.*)", first) or re.search(r"missing (.*) at '", first)
    if not errors_match:
      return set()

    text = errors_match.group(1).strip().strip("{}")
    unfiltered = re.findall(r"'[^']+'|[A-Za-z_]+", text)
    return {token.strip("'") for token in unfiltered}

  def next_grammatical_tokens(self):
    """Parse the current prefix and return the tokens the grammar expects next.

    Because the prefix is an incomplete query, parsing it with the `query`
    rule reports a syntax error whose message lists the expected tokens.
    Returns an empty set if the prefix parses cleanly or the message cannot
    be interpreted.
    """
    prefix = " ".join(self.generated_tokens)
    lexer = SimpleSQLLexer(InputStream(prefix))
    parser = SimpleSQLParser(CommonTokenStream(lexer))

    # Swap the parser's default listeners for one that collects the messages.
    error_listener = GuidedTokenErrorListener()
    parser.removeErrorListeners()
    parser.addErrorListener(error_listener)

    try:
      parser.query()
    except Exception:
      # Best effort: a failed parse may still have reported the errors we need.
      pass

    return self.extract_from_errors(error_listener.syntax_errors)

  def safe_tokens(self):
    """Return the PII-free, not-yet-exhausted tokens allowed after the prefix."""
    last_token = self.generated_tokens[-1] if self.generated_tokens else ""

    if last_token == "FROM":
      # After FROM, offer the schema's table names directly.
      candidates = set(self.schema_tables)
    else:
      candidates = set()

      # Heuristic (the grammar-derived set was not reliable enough here): offer
      # operators while WHERE/HAVING clauses outnumber operators in the prefix.
      operators = {"AND", "OR", "=", "!=", "<", ">", "<=", ">=", "IN", "LIKE", "BETWEEN"}
      n_operators = sum(1 for t in self.generated_tokens if t in operators)
      n_clauses = sum(1 for t in self.generated_tokens if t in {"WHERE", "HAVING"})
      if n_operators < n_clauses:
        candidates.update(operators)

      if last_token in self.schema_tables:
        candidates.update({"WHERE", "GROUP", "LIMIT", "JOIN"})

      for token in self.next_grammatical_tokens():
        if token == "IDENT":
          # A generic identifier slot: offer the schema's column names instead.
          candidates.update(self.schema_cols)
          continue
        if token == "FROM":
          # Presumably: where FROM can follow, another select item (',') can too.
          candidates.add(",")
        candidates.add(token)

    safe = {token for token in candidates if token not in self.pii_tokens}

    # Drop choices already proven to be dead ends at this position. This applies
    # to every branch, including FROM; otherwise backtracking could loop forever.
    exhausted = self.exhausted_paths.get(len(self.generated_tokens), set())
    return safe - exhausted

  def ask_llm(self):
    """Ask the LLM to pick the next token from `self.candidate_tokens`.

    Returns:
      The chosen token, or None if the LLM asked to stop or its answer is not
      acceptable. A <STOP> answer sets `is_finished` only when the current
      prefix passes `is_complete`.
    """
    token_list = sorted(self.candidate_tokens)
    # Computed outside the f-strings: reusing the outer quote character inside
    # an f-string replacement field is a SyntaxError before Python 3.12.
    prefix_text = " ".join(self.generated_tokens)

    prompt = (
      "You are selecting the NEXT SQL TOKEN for an autocomplete system.\n"
      "You MUST behave like a deterministic parser, not a natural language model.\n"
      f"User question: {self.user_prompt}\n"
      f"Current SQL prefix: \"{prefix_text}\"\n\n"
      f"Tokens already selected so far: {', '.join(self.generated_tokens)}\n"
      f"{self.schema_text}\n\n"
      f"Valid next tokens (choose ONLY from this list): {', '.join(token_list)}\n"
      "IMPORTANT RULES (you MUST obey all):\n"
      "1. You MUST choose exactly ONE token from the list.\n"
      "2. NEVER select a column that is already present in the SQL prefix after SELECT.\n"
      "3. NEVER repeat the same column twice.\n"
      "4. The Current SQL Prefix + the Token you choose MUST answer the user's question.\n"
      "5. If the SQL Query is incomplete, or the user's question has not been answered, you MUST CHOOSE A TOKEN.\n\n"

      "YOU MUST return exactly one token FROM THE TOKEN LIST, or <STOP> (look at STOP CONDITION below). DO NOT RETURN ANY OTHER STRINGS.\n"

      "STOP CONDITION:\n"
      "If the current SQL prefix already forms a complete SQL query AND it fully answers the user's question.\n"
      "(e.g., SELECT ... FROM ..., or SELECT ... FROM ... WHERE col = literal, or SELECT ... FROM ... WHERE col = literal LIMIT 100),\n"
      "then return ONLY: <STOP>.\n"
      "Only return <STOP> if the SQL ALREADY contains both: SELECT <columns>, FROM <table>\n"
      "Do NOT return <STOP> after just SELECT, or SELECT column, or SELECT column FROM.\n"
    )

    llm_output = self.llm(prompt).strip().replace('"', '').replace("'", "").strip()

    print("LLM Choice: ", llm_output)

    if llm_output == "<STOP>":
      if self.is_complete(self.generated_tokens):
        self.is_finished = True
      return None

    # An empty reply is never a usable token, even for free-form literal slots.
    if not llm_output:
      return None

    # Keywords, columns and tables must match a candidate exactly.
    if llm_output in token_list:
      return llm_output

    # Placeholder slots accept free-form literals.
    if "USER_DEFINED_NUMBER" in token_list and re.fullmatch(r"\d+", llm_output):
      return llm_output
    if "USER_DEFINED_STRING" in token_list:
      # NOTE(review): quotes were stripped above, so the literal is appended
      # unquoted; confirm the SimpleSQL lexer accepts that as a string.
      return llm_output
    return None

  def exhaust_token(self, index, token):
    """Record `token` as a dead end at prefix position `index`.

    Entries for positions after `index` were computed under the prefix being
    abandoned. They are discarded so they cannot prune the new branch.
    """
    for stale_index in [i for i in self.exhausted_paths if i > index]:
      del self.exhausted_paths[stale_index]
    exhausted = self.exhausted_paths.setdefault(index, set())
    if token is not None:
      exhausted.add(token)

  def _backtrack(self, llm_choice):
    """Pop the last token and mark it exhausted; no-op on an empty prefix."""
    if not self.generated_tokens:
      return {"status": "Nothing to backtrack", "LLM choice": llm_choice}
    exhausted_token = self.generated_tokens.pop()
    self.exhaust_token(len(self.generated_tokens), exhausted_token)
    return {"status": "Backtrack", "LLM choice": llm_choice}

  def forward(self):
    """Advance the search by one step and return a status dict.

    The dict reports one of four outcomes: the accepted token plus the new
    prefix, completion of the query, a backtrack, or that there was nothing
    left to backtrack.
    """
    self.candidate_tokens = self.safe_tokens()

    if not self.candidate_tokens:
      # Dead end: accept the prefix if it already forms a complete query.
      if self.is_complete(self.generated_tokens):
        self.is_finished = True
        return {"status": "SQL query complete", "LLM choice": "N/A"}
      return self._backtrack("N/A")

    llm_choice = self.ask_llm()

    # No usable choice: either the LLM validly stopped, or we backtrack.
    if llm_choice is None:
      if self.is_finished:
        return {"status": "SQL query complete", "LLM choice": llm_choice}
      return self._backtrack(llm_choice)

    self.generated_tokens.append(llm_choice)
    return {"LLM choice": llm_choice, "SQL prefix": list(self.generated_tokens)}

  def is_complete(self, tokens):
    """Heuristically decide whether `tokens` form a finished query.

    Requires SELECT and FROM with at least one token after the first FROM.
    If WHERE is present, at least three tokens must follow it
    (e.g. `col op value`).
    """
    if "SELECT" not in tokens or "FROM" not in tokens:
      return False

    if tokens.index("FROM") == len(tokens) - 1:
      return False

    if "WHERE" in tokens:
      if len(tokens) - tokens.index("WHERE") < 4:
        return False
    return True

  def finished(self):
    """Return True once generation has been marked complete."""
    return self.is_finished


7. Configure the function that invokes the Gemini LLM. You can replace it with any other function that takes a prompt string and returns the model's text response.

In [None]:
import os
from getpass import getpass

import google.generativeai as genai

# Never hard-code the API key in the notebook, where it would be saved with it.
# Read it from GOOGLE_API_KEY, or prompt for it without echoing.
GEMINI_API_KEY = os.environ.get("GOOGLE_API_KEY") or getpass("Gemini API key: ")
genai.configure(api_key=GEMINI_API_KEY)

llm = genai.GenerativeModel("gemini-2.5-flash-lite")

def gemini_llm(prompt: str) -> str:
  """Send `prompt` to Gemini (temperature 0) and return its stripped text.

  `response.text` raises ValueError when the response contains no text part
  (e.g. it was blocked). In that case "" is returned, so the caller treats it
  as an invalid choice instead of aborting the whole generation loop.
  """
  response = llm.generate_content(
    prompt, generation_config={"temperature": 0, "max_output_tokens": 20}
  )
  try:
    return response.text.strip()
  except ValueError:
    return ""

8. Function that returns the tables and columns of a Spider database.

In [11]:
def get_schema(db_id):
  """Look up `db_id` in the Spider `tables.json` data and summarize its schema.

  Returns:
    {"tables": [lower-cased table names],
     "table_columns": {lower-cased table name: [original column names]}},
    or None if no database has that id.
  """
  entry = next((db for db in spider_tables if db["db_id"] == db_id), None)
  if entry is None:
    return None

  tables = [name.lower() for name in entry['table_names_original']]
  columns_by_table = {name: [] for name in tables}
  for table_idx, column in entry['column_names_original']:
    # Table index -1 marks a column not tied to any table
    # (presumably Spider's "*" pseudo-column); skip it.
    if table_idx == -1:
      continue
    columns_by_table[tables[table_idx]].append(column)
  return {"tables": tables, "table_columns": columns_by_table}

9. Run guided token generation (at most 20 steps per question)

In [12]:
# Column/table names the generator must never emit (compared lower-cased).
PII_TOKENS = {"ssn", "email"}

guided_token_generator = GuidedTokenGenerator(
  pii_tokens=PII_TOKENS,
  llm=gemini_llm,
  get_schema=get_schema,
)

In [13]:
def GuidedTokenGeneration(user_question, db, max_steps=20, generator=None):
  """Run guided token generation for `user_question` against database `db`.

  Args:
    user_question: natural-language question to translate into SQL.
    db: db id passed to the generator's `start`.
    max_steps: upper bound on `forward` calls. Each call may query the LLM.
    generator: object exposing `start`, `forward`, `finished` and
      `generated_tokens`. Defaults to the notebook-level
      `guided_token_generator`.

  Returns:
    The generated tokens joined by single spaces. This may be a partial query
    if `max_steps` was reached before the generator finished; a message is
    printed in that case.
  """
  if generator is None:
    generator = guided_token_generator

  db_name = generator.start(user_question, db)
  print("DB:", db_name)

  for step in range(1, max_steps + 1):
    if generator.finished():
      break
    out = generator.forward()
    print(f"{step}: {out}")

  if not generator.finished():
    print(f"Stopped after {max_steps} steps without finishing; the query may be partial.")

  return " ".join(generator.generated_tokens)

10. Test with different user questions!

In [14]:
# Example questions from the Spider training set (department_management db).
# The last one is the question run below; pick another index to try the others.
EXAMPLE_QUESTIONS = [
  "How many heads of the departments are older than 56?",
  "List the creation year, name and budget of each department.",
  "What is the average number of employees of the departments whose rank is between 10 and 15?",
  "What are the names of the heads who are born outside the California state?",
]
user_question = EXAMPLE_QUESTIONS[-1]

sql_query = GuidedTokenGeneration(user_question, 'department_management')
print("Final SQL Query: ", sql_query)

DB: department_management
LLM Choice:  SELECT
1: {'LLM choice': 'SELECT', 'SQL prefix': ['SELECT']}
LLM Choice:  name
2: {'LLM choice': 'name', 'SQL prefix': ['SELECT', 'name']}
LLM Choice:  ,
3: {'LLM choice': ',', 'SQL prefix': ['SELECT', 'name', ',']}
LLM Choice:  born_state
4: {'LLM choice': 'born_state', 'SQL prefix': ['SELECT', 'name', ',', 'born_state']}
LLM Choice:  FROM
5: {'LLM choice': 'FROM', 'SQL prefix': ['SELECT', 'name', ',', 'born_state', 'FROM']}
LLM Choice:  head
6: {'LLM choice': 'head', 'SQL prefix': ['SELECT', 'name', ',', 'born_state', 'FROM', 'head']}
LLM Choice:  WHERE
7: {'LLM choice': 'WHERE', 'SQL prefix': ['SELECT', 'name', ',', 'born_state', 'FROM', 'head', 'WHERE']}
LLM Choice:  born_state
8: {'LLM choice': 'born_state', 'SQL prefix': ['SELECT', 'name', ',', 'born_state', 'FROM', 'head', 'WHERE', 'born_state']}
LLM Choice:  !=
9: {'LLM choice': '!=', 'SQL prefix': ['SELECT', 'name', ',', 'born_state', 'FROM', 'head', 'WHERE', 'born_state', '!=']}
LLM Choi