This notebook demonstrates how to construct a text-to-SQL pipeline using the Gemini model and the BIRD benchmark.


# Setup: Example input query and DB Schema


### Download DB and setup


In [None]:
# first download the sqlite file for the db
import gdown
import os
import json

def load_json_lines(filename):
    """Load a JSON Lines file into a list of decoded objects.

    Args:
        filename: Path to a text file holding one JSON document per line.

    Returns:
        A list with one decoded value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            offending line number is printed before the error propagates.
    """
    data = []
    # Read as UTF-8 explicitly so the result does not depend on the
    # platform's default encoding.
    with open(filename, "r", encoding="utf-8") as file:
        for line_number, line in enumerate(file, start=1):
            line = line.strip()
            if not line:
                # Blank (or whitespace-only) lines carry no record.
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON on line {line_number}: {e}")
                # Bare raise keeps the original traceback intact.
                raise
    return data

# Fetch the benchmark database plus the example files, skipping any artifact
# that is already present in the working directory.
# Each entry: (path to check, Google Drive file id, download output name).
required_downloads = (
    ("./debit_card_specializing.sqlite", "1qeVvDWz63bUkNj9V3msxYymcGJ7V3DR9", "debit_card_specializing.sqlite"),
    ("./cars_mschema.txt", "1Iv-uNHJWSlA5ZJKUYkko4Zcf1uFPVwWp", "cars_mschema.txt"),
    ("./debit_card_queries.jsonl", "159bSO1vUlStYrn1EnDfK7ZgvmrJv3enb", "debit_card_queries.jsonl"),
)
for local_path, drive_id, output_name in required_downloads:
    if os.path.exists(local_path):
        continue
    gdown.download(id=drive_id, output=output_name, quiet=False)

# Sanity check: the database must be on disk before anything else can run.
assert os.path.exists("./debit_card_specializing.sqlite"), "Cannot find the sqlite file in cwd. Something must have gone wrong!"

# One dict per benchmark question (question, evidence, gold SQL, ...).
input_data = load_json_lines("debit_card_queries.jsonl")

## Evaluation: Execution Accuracy
A standard metric in BIRD is execution accuracy: given a gold SQL query for a user question on a database, a predicted SQL query is correct when executing the gold and predicted queries yields the same results, potentially ignoring column/row ordering.

In [None]:
import sqlite3

def execute_sql(sql, db_path, timeout=15):
  """Run one SQL statement against a SQLite database and return all rows.

  Args:
    sql: The SQL statement to execute.
    db_path: Path to the SQLite database file (or ":memory:").
    timeout: Seconds to wait on a locked database before giving up.

  Returns:
    A list of row tuples as produced by `cursor.fetchall()`.

  Raises:
    Whatever `sqlite3` raises for an invalid statement or an unusable database.
  """
  conn = sqlite3.connect(db_path, timeout=timeout)
  try:
    # `with conn` only commits/rolls back the transaction; it does not close
    # the connection, hence the explicit close in `finally`.
    with conn:
      # busy_timeout is expressed in whole milliseconds.
      conn.execute(f"PRAGMA busy_timeout = {int(timeout * 1000)};")
      return conn.execute(sql).fetchall()
  finally:
    conn.close()


def compare_outputs(gold_output, pred_output, use_set=True):
  """Decide whether two query results are equivalent.

  Float values are rounded to 2 decimal places before comparison so that
  tiny floating-point differences do not count as mismatches; all other
  values (ints, strings, None, ...) are compared as-is.

  Args:
    gold_output: Rows (sequences of values) returned by the gold SQL.
    pred_output: Rows (sequences of values) returned by the predicted SQL.
    use_set: If True, ignore row order and duplicate rows by comparing the
      two results as sets of rows. If False, rows must match position by
      position.

  Returns:
    True if the normalized results are equal, False otherwise. Column order
    inside a row always matters.
  """
  def _normalize(rows):
    return [
        tuple(round(x, 2) if isinstance(x, float) else x for x in row)
        for row in rows
    ]

  gold_rows = _normalize(gold_output)
  pred_rows = _normalize(pred_output)
  if use_set:
    # Compare with set equality: walking two sets in parallel is unreliable
    # because equal sets are not guaranteed to iterate in the same order.
    return set(gold_rows) == set(pred_rows)
  return gold_rows == pred_rows

def execution_accuracy(gold_sql, pred_sql, db_path):
  """Execute the gold and predicted SQL and report whether their results match.

  Args:
    gold_sql: Reference SQL; errors while executing it propagate to the caller.
    pred_sql: Predicted SQL; if executing it fails, the prediction is scored
      as incorrect instead of raising.
    db_path: Path to the SQLite database both queries run against.

  Returns:
    A dict with keys "gold_output" (rows of the gold query), "pred_output"
    (rows of the predicted query, or [] if it failed to execute) and
    "correct" (bool).
  """
  gold_output = execute_sql(gold_sql, db_path)
  try:
    pred_output = execute_sql(pred_sql, db_path)
  except Exception as e:
    # `Exception` (not a bare except) so Ctrl-C / KeyboardInterrupt still
    # stops a long evaluation run.
    print(f"Error executing SQL: {pred_sql}")
    print(f"  {type(e).__name__}: {e}")
    # A query that cannot run is never correct, even when the gold query
    # happens to return no rows.
    return {
        "gold_output": gold_output,
        "pred_output": [],
        "correct": False
    }
  return {
      "gold_output": gold_output,
      "pred_output": pred_output,
      "correct": compare_outputs(gold_output, pred_output)
  }

# The importance of context

## Simple DDL solution

Data Definition Language (DDL) is the subset of SQL that defines and manages database schema structures through statements like CREATE, ALTER, and DROP, specifying tables, columns, data types, constraints, indexes, and relationships within a database. We can extract this information directly from the database to serve as an initial context for the model.

In [None]:
def get_ddl_description(db_path):
  """Return the CREATE TABLE statements of every table in a SQLite database.

  Args:
    db_path: Path to the SQLite database file.

  Returns:
    A single string with the `sql` entry of each row of `sqlite_master`
    whose type is 'table', one statement after another, each followed by a
    newline. An empty string if the database has no tables.
  """
  conn = sqlite3.connect(db_path)
  try:
    # One query fetches every table's DDL, so no table name ever has to be
    # spliced into SQL text (which broke on names containing a quote).
    rows = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type='table';"
    ).fetchall()
  finally:
    conn.close()
  return "".join(f"{ddl}\n" for (ddl,) in rows)


# `db_path` is a module-level name: evaluate_method() later reads it when
# executing the gold and predicted SQL.
db_path = "debit_card_specializing.sqlite"
# Raw CREATE TABLE statements, used as the simplest database context.
ddl_description = get_ddl_description(db_path)

print(ddl_description)

CREATE TABLE customers
(
    CustomerID INTEGER UNIQUE     not null
        primary key,
    Segment    TEXT null,
    Currency   TEXT null
)
CREATE TABLE gasstations
(
    GasStationID INTEGER    UNIQUE   not null
        primary key,
    ChainID      INTEGER          null,
    Country      TEXT null,
    Segment      TEXT null
)
CREATE TABLE products
(
    ProductID   INTEGER   UNIQUE      not null
        primary key,
    Description TEXT null
)
CREATE TABLE "transactions_1k"
(
    TransactionID INTEGER
        primary key autoincrement,
    Date          DATE,
    Time          TEXT,
    CustomerID    INTEGER,
    CardID        INTEGER,
    GasStationID  INTEGER,
    ProductID     INTEGER,
    Amount        INTEGER,
    Price         REAL
)
CREATE TABLE sqlite_sequence(name,seq)
CREATE TABLE "yearmonth"
(
    CustomerID  INTEGER not null
        references customers
            on update cascade on delete cascade
        references customers,
    Date        TEXT    not null,
    C

## More informative context with mSchema

[M-Schema](https://github.com/XGenerationLab/M-Schema) is an attempt at creating a more informative context on top of a more LLM-friendly presentation of the database’s schema. The key idea behind mSchema is to use SQLAlchemy’s reflection to expose the connections between tables (their foreign key relationships) and to include example values for each column, improving the model's comprehension. The output below shows how representative examples are added to each column alongside the column’s name and type.


In [None]:
# Clone the M-Schema repository into ./mschema and install its requirements.
# NOTE(review): the clone target is the relative path `mschema`, while the
# requirements file is read from the absolute path /content/mschema — this
# presumably assumes the working directory is /content (e.g. Colab); confirm
# before running elsewhere.
# Re-running this cell after a successful clone makes git report that the
# destination already exists.
!git clone https://github.com/XGenerationLab/M-Schema.git mschema
!pip install -r /content/mschema/requirements.txt

fatal: destination path 'mschema' already exists and is not an empty directory.


In [None]:
import sys
# Put the cloned M-Schema checkout on the import path so its modules can be
# imported directly (presumably `schema_engine`, imported below — confirm).
sys.path.append('/content/mschema/')

In [None]:
from sqlalchemy import create_engine
from schema_engine import SchemaEngine

# Same database file as in the DDL step.
db_path = "debit_card_specializing.sqlite"
db_name = "debit_card_specializing"
# The SQLAlchemy SQLite URL is built from the absolute path of the file.
abs_path = os.path.abspath(db_path)
db_engine = create_engine(f'sqlite:///{abs_path}')
# SchemaEngine builds the schema object from the engine; to_mschema() renders
# it as the text that is later used as prompt context.
schema_engine = SchemaEngine(engine=db_engine, db_name=db_name)
mschema = schema_engine.mschema
mschema_str = mschema.to_mschema()
print(mschema_str)


【DB_ID】 debit_card_specializing
【Schema】
# Table: main.customers
[
(CustomerID:INTEGER, Primary Key, Examples: [3, 5, 6]),
(Segment:TEXT, Examples: [SME, LAM, KAM]),
(Currency:TEXT, Examples: [EUR, CZK])
]
# Table: main.gasstations
[
(GasStationID:INTEGER, Primary Key, Examples: [44, 45, 46]),
(ChainID:INTEGER, Examples: [13, 6, 23]),
(Country:TEXT, Examples: [CZE, SVK]),
(Segment:TEXT, Examples: [Value for money, Premium, Other])
]
# Table: main.products
[
(ProductID:INTEGER, Primary Key, Examples: [1, 2, 3]),
(Description:TEXT, Examples: [Rucní zadání, Nafta, Special])
]
# Table: main.transactions_1k
[
(TransactionID:INTEGER, Primary Key, Examples: [1, 2, 3]),
(Date:DATE, Examples: [2012-08-24]),
(Time:TEXT, Examples: [09:41:00, 10:03:00, 13:53:00]),
(CustomerID:INTEGER, Examples: [31543, 46707, 7654]),
(CardID:INTEGER, Examples: [486621, 550134, 684220]),
(GasStationID:INTEGER, Examples: [3704, 656, 741]),
(ProductID:INTEGER, Examples: [2, 23, 5]),
(Amount:INTEGER, Examples: [28, 18

# Generate with Gemini given the context

Now that we have the context for the databases, let's compare how Gemini performs on each DB context.

To use the Gemini model in this section, you'll need to obtain an API key from Google AI Studio. Follow these steps:

1. Get your API key: Follow [these instructions](https://ai.google.dev/gemini-api/docs/api-key) to create and manage your Gemini API keys
2. Set up authentication: Replace "YOUR-GEMINI-API-KEY-HERE" with your actual API key



In [None]:
from google.genai import types
import re
from pprint import pprint
from functools import partial
from google import genai

# The string below is a placeholder: substitute your own Gemini API key, and
# avoid committing a real key together with the notebook.
client = genai.Client(
    api_key="YOUR-GEMINI-API-KEY-HERE"
)


# get model's response
def get_response(client, contents, model='gemini-1.5-flash', temperature=0.7):
  """Send `contents` to a Gemini model with the fixed text-to-SQL system prompt.

  Args:
    client: Client object exposing `models.generate_content`.
    contents: The request payload (a prompt string or a list of messages).
    model: Name of the model to query.
    temperature: Sampling temperature passed to the generation config.

  Returns:
    Whatever `client.models.generate_content` returns.
  """
  generation_config = types.GenerateContentConfig(
      system_instruction="You are a SQLite expert tasked with writing SQL for a given natural language user query. You would be given database information in form of CREATE TABLE statements; External Knowledge which are hints; user natural language query. Your task is to write valid SQLite to answer the user questions for the tables provided.",
      max_output_tokens=500,
      temperature=temperature,
  )
  return client.models.generate_content(
      model=model,
      config=generation_config,
      contents=contents,
  )


def assemble_prompt(example, db_context):
    """Build the user prompt for one benchmark example.

    Args:
        example: Mapping with a "question" (natural-language query) and an
            "evidence" (external-knowledge hint) entry.
        db_context: Textual description of the database schema.

    Returns:
        The prompt string: schema context, the external-knowledge line, the
        task instruction, the question, and the output-format instruction.
    """
    question = example["question"]
    hint = example["evidence"]
    pieces = [
        db_context,
        "\n\n",
        f"-- External Knowledge: {hint}\n",
        "-- Using valid SQLite and understanding External Knowledge, answer the following questions for the tables provided above.\n",
        question,
        "\nJust output SQL starting with SELECT directly wrapped in a ```sql ``` block.",
    ]
    return "".join(pieces)

def extract_sql_from_code_block(response_txt):
    """Return the stripped contents of every ```sql fenced block in the text.

    The fence language tag is matched case-insensitively and must be
    followed by whitespace. Returns an empty list when no block is found.
    """
    fenced_sql = re.compile(r"```sql\s+(.*?)```", re.DOTALL | re.IGNORECASE)
    return [found.group(1).strip() for found in fenced_sql.finditer(response_txt)]

def create_content(role, txt):
    """Wrap plain text as a single-part Content message for the given role."""
    # Gemini requests take messages as Content objects made of Parts.
    text_part = types.Part.from_text(text=txt)
    return types.Content(role=role, parts=[text_part])

def update_request_with_history(history_msg: list, request: str) -> list:
    """Return a new list: the history messages followed by `request` as a user turn."""
    new_user_turn = [create_content("user", request)]
    return history_msg + new_user_turn

def truncate_str(s, truncate_len):
    """Cut `s` to `truncate_len` characters, marking the cut with a suffix.

    Strings no longer than `truncate_len` are returned unchanged.
    """
    return s if len(s) <= truncate_len else s[:truncate_len] + " (truncated...)"

def print_result(result_dict: dict, truncate_length=60):
    """Print each `key: value` pair on its own line, truncating long values."""
    report = "".join(
        f"{key}: {truncate_str(str(value), truncate_length)}\n"
        for key, value in result_dict.items()
    )
    print(report)

def evaluate_method(input_data, get_prompt, max_iter=None, history_msg=None):
    """
    Evaluate a prompting strategy on Gemini and print its execution accuracy.

    Given
      `input_data` a list of queries and gold SQL (each example needs the
        "question" and "SQL" keys),
      `get_prompt` a function that takes in an example and returns a prompt,
      `max_iter` an optional cap on the number of examples to evaluate
        (None evaluates all of them),
      `history_msg` a list of messages to be sent to the model (use for fewshot examples),
    evaluate it on Gemini.

    Relies on the module-level `client` (Gemini client) and `db_path`
    (SQLite file the queries are executed against).

    Returns a dict mapping the example index to
    {"result": bool, "pred_sql": str}; "pred_sql" is "" when the response
    contained no ```sql block.
    """
    total = 0
    correct_count = 0

    question_stats = {}
    for example_id, example in enumerate(input_data):
      # Stop before starting example number max_iter + 1, so that exactly
      # max_iter examples are evaluated.
      if max_iter is not None and example_id >= max_iter:
        break

      request = get_prompt(example)
      if history_msg is not None:
        request = update_request_with_history(history_msg, request)
      # send message to model: optionally attach history_msg
      response = get_response(client, request)

      # NOTE(review): response.text is presumably None when the model returns
      # no text part — guard so the regex search does not raise; confirm
      # against the SDK.
      candidates = extract_sql_from_code_block(response.text or "")
      gold_sql = example["SQL"]

      print(f"\n**Question {example_id+1}**\n\t{example['question']}")

      if candidates:
        # Only the first fenced block is treated as the prediction.
        pred_sql = candidates[0]
        result = execution_accuracy(gold_sql, pred_sql, db_path)
        is_correct = result['correct']
        print_result(result)
      else:
        # No SQL to run: score as incorrect rather than crash the whole run.
        pred_sql = ""
        is_correct = False
        print("No ```sql block found in the model response; counting as incorrect.")

      correct_count += is_correct
      total += 1
      question_stats[example_id] = {"result": is_correct, "pred_sql": pred_sql}

    # Guard against empty input (or max_iter=0) to avoid dividing by zero.
    accuracy = correct_count / total if total else 0.0
    print(f"\nAccuracy: {accuracy}")
    return question_stats

## Evaluation: DDL + Gemini

We should get around 54.7% accuracy on the `debit_card_specializing` subset.

In [None]:
# Fix the DDL text as the database context, leaving a one-argument prompt
# builder (example -> prompt) as evaluate_method expects.
get_basic_ddl_prompt = partial(assemble_prompt, db_context=ddl_description)

results = evaluate_method(input_data, get_basic_ddl_prompt)


**Question 1**
	How many gas stations in CZE has Premium gas?
gold_output: [(1114,)]
pred_output: [(0,)]
correct: False


**Question 2**
	What is the ratio of customers who pay in EUR against customers who pay in CZK?
gold_output: [(0.06572769953051644,)]
pred_output: [(0.06572769953051644,)]
correct: True


**Question 3**
	In 2012, who had the least consumption in LAM?
gold_output: [(47273,)]
pred_output: [(7653,)]
correct: False


**Question 4**
	What was the average monthly consumption of customers in SME for the year 2013?
gold_output: [(459.9562642871061,)]
pred_output: [(459.9562642871061,)]
correct: True


**Question 5**
	Which customers, paying in CZK, consumed the most gas in 2011?
gold_output: [(603,)]
pred_output: [(603,)]
correct: True



KeyboardInterrupt: 

## Evaluation: mSchema + Gemini

We should get an improved accuracy of around 60.9% on the same subset with mSchema as the context.

In [None]:
# Same evaluation as the DDL run, but with the mSchema text as the database
# context.
get_basic_mschema_prompt = partial(assemble_prompt, db_context=mschema_str)

results = evaluate_method(input_data, get_basic_mschema_prompt)


**Question 1**
	How many gas stations in CZE has Premium gas?
gold_output: [(1114,)]
pred_output: [(1114,)]
correct: True


**Question 2**
	What is the ratio of customers who pay in EUR against customers who pay in CZK?
gold_output: [(0.06572769953051644,)]
pred_output: [(0.06572769953051644,)]
correct: True


**Question 3**
	In 2012, who had the least consumption in LAM?
gold_output: [(47273,)]
pred_output: [(7653,)]
correct: False


**Question 4**
	What was the average monthly consumption of customers in SME for the year 2013?
gold_output: [(459.9562642871061,)]
pred_output: [(459.9562642871061,)]
correct: True


**Question 5**
	Which customers, paying in CZK, consumed the most gas in 2011?
gold_output: [(603,)]
pred_output: [(603,)]
correct: True


**Question 6**
	How many customers in KAM had a consumption of less than 30,000 for the year 2012?
gold_output: [(1123,)]
pred_output: [(1746,)]
correct: False


**Question 7**
	What was the difference in gas consumption between CZK-payi

# Adding few shot examples

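Demonstrations can be useful: here we prepend a worked example (question, evidence, and gold SQL) from a different database, `cars`, as earlier conversation turns before asking the actual question.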

In [None]:
# Read the schema text of the `cars` database used by the few-shot examples.
# The context manager closes the file handle; the encoding is pinned to UTF-8
# so reading does not depend on the platform default.
with open("cars_mschema.txt", "r", encoding="utf-8") as cars_mschema_file:
    cars_mschema = cars_mschema_file.read()
# Worked demonstrations on the `cars` database. Each entry uses the keys
# "db_id", "question", "evidence" and "SQL"; get_fewshot_prompt() turns one
# entry into a user/model message pair.
fewshot_examples = [
    {
        "db_id": "cars",
        "question": "List the car's name with a price worth greater than 85% of the average price of all cars.",
        "evidence": "car's name refers to car_name; a price worth greater than 85% of the average price of all cars refers to price > multiply(avg(price), 0.85)",
        "SQL": "SELECT T1.car_name FROM data AS T1 INNER JOIN price AS T2 ON T1.ID = T2.ID WHERE T2.price * 100 > ( SELECT AVG(price) * 85 FROM price )",
    },
    {
        "db_id": "cars",
        "question": "Calculate the average production rate per year from 1971 to 1980. Among them, name the cars with a weight of fewer than 1800 lbs.",
        "evidence": "from 1971 to 1980 refers to model_year between 1971 and 1980; average production rate per year = divide(count(ID where model_year between 1971 and 1980), 9); car's name refers to car_name; a weight of fewer than 1800 lbs refers to weight < 1800",
        "SQL": "SELECT CAST(COUNT(T1.ID) AS REAL) / 9 FROM production AS T1 INNER JOIN data AS T2 ON T2.ID = T1.ID WHERE T1.model_year BETWEEN 1971 AND 1980 UNION ALL SELECT DISTINCT T2.car_name FROM production AS T1 INNER JOIN data AS T2 ON T2.ID = T1.ID WHERE T1.model_year BETWEEN 1971 AND 1980 AND T2.weight < 1800",
    }
]

def get_fewshot_prompt(index):
    """Build the two-message history for the few-shot example at `index`.

    Returns a list holding a user message (the assembled prompt for the
    example, using the cars schema as context) followed by a model message
    with the example's SQL wrapped in a ```sql block.
    """
    assert index < len(fewshot_examples), f"Index {index} out of range"
    shot = fewshot_examples[index]
    question_turn = create_content("user", assemble_prompt(shot, cars_mschema))
    answer_turn = create_content("model", f"```sql\n{shot['SQL']}\n```")
    return [question_turn, answer_turn]


# The prompt for the real question still uses the mSchema context; the
# demonstration is passed separately as message history.
get_basic_mschema_prompt = partial(assemble_prompt, db_context=mschema_str)

# Index 1 selects the second entry of fewshot_examples as the demonstration.
fewshot_prompt = get_fewshot_prompt(1)
# fewshot_prompt = None
results = evaluate_method(input_data, get_basic_mschema_prompt, history_msg=fewshot_prompt)



**Question 1**
	How many gas stations in CZE has Premium gas?
{'correct': True, 'gold_output': [(1114,)], 'pred_output': [(1114,)]}

**Question 2**
	What is the ratio of customers who pay in EUR against customers who pay in CZK?
{'correct': True,
 'gold_output': [(0.06572769953051644,)],
 'pred_output': [(0.06572769953051644,)]}

**Question 3**
	In 2012, who had the least consumption in LAM?
{'correct': False, 'gold_output': [(47273,)], 'pred_output': [(7653,)]}

**Question 4**
	What was the average monthly consumption of customers in SME for the year 2013?
{'correct': False,
 'gold_output': [(459.9562642871061,)],
 'pred_output': [(5519.475171445273,)]}

**Question 5**
	Which customers, paying in CZK, consumed the most gas in 2011?
{'correct': True, 'gold_output': [(603,)], 'pred_output': [(603,)]}

**Question 6**
	How many customers in KAM had a consumption of less than 30,000 for the year 2012?
{'correct': False, 'gold_output': [(1123,)], 'pred_output': [(11139,)]}

**Question 7**
