## Load & Preprocess Dataset

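Download the train dataset from [BIRD-bench](https://bird-bench.github.io/) before running this notebook. The cells below expect `train/train.json`, `train/train_tables.json`, and the `works_cycles/works_cycles.sqlite` database, all relative to the working directory.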

### Prepare utility functions

In [None]:
import sqlite3
import json
import time
import numpy as np

# Maps a database id to a list of table names. Only "works_cycles" is listed,
# and the mapping is not referenced by any other code in the visible part of
# this notebook.
# NOTE(review): presumably the table inventory of the BIRD `works_cycles`
# database — confirm against the actual SQLite file.
db_table_map = {
    "works_cycles": [
        "CountryRegion",
        "Culture",
        "Currency",
        "CountryRegionCurrency",
        "Person",
        "BusinessEntityContact",
        "EmailAddress",
        "Employee",
        "Password",
        "PersonCreditCard",
        "ProductCategory",
        "ProductDescription",
        "ProductModel",
        "ProductModelProductDescriptionCulture",
        "ProductPhoto",
        "ProductSubcategory",
        "SalesReason",
        "SalesTerritory",
        "SalesPerson",
        "SalesPersonQuotaHistory",
        "SalesTerritoryHistory",
        "ScrapReason",
        "Shift",
        "ShipMethod",
        "SpecialOffer",
        "BusinessEntityAddress",
        "SalesTaxRate",
        "Store",
        "SalesOrderHeaderSalesReason",
        "TransactionHistoryArchive",
        "UnitMeasure",
        "ProductCostHistory",
        "ProductDocument",
        "ProductInventory",
        "ProductProductPhoto",
        "ProductReview",
        "ShoppingCartItem",
        "SpecialOfferProduct",
        "SalesOrderDetail",
        "TransactionHistory",
        "Vendor",
        "ProductVendor",
        "PurchaseOrderHeader",
        "PurchaseOrderDetail",
        "WorkOrder",
        "WorkOrderRouting",
        "Customer",
        "ProductListPriceHistory",
        "Address",
        "AddressType",
        "BillOfMaterials",
        "BusinessEntity",
        "ContactType",
        "CurrencyRate",
        "Department",
        "EmployeeDepartmentHistory",
        "EmployeePayHistory",
        "JobCandidate",
        "Location",
        "PhoneNumberType",
        "Product",
        "Document",
        "StateProvince",
        "CreditCard",
        "SalesOrderHeader"
    ],
}


def nice_look_table(column_names: list, values: list):
    """Render column names and rows as a right-aligned plain-text table.

    Every column is as wide as its longest cell (the header included). Each
    cell is right-justified to that width and followed by a single space.

    Args:
        column_names: Header strings, one per column.
        values: Rows; each row must be indexable for every column position.

    Returns:
        str: The header line, a newline, then the rows joined by newlines.
    """
    # Width of a column = longest stringified cell, header included.
    col_widths = []
    for idx in range(len(column_names)):
        longest = 0
        for cells in values + [column_names]:
            longest = max(longest, len(str(cells[idx])))
        col_widths.append(longest)

    header_line = "".join(
        name.rjust(width) + " " for name, width in zip(column_names, col_widths)
    )
    body = "\n".join(
        "".join(str(cell).rjust(width) + " " for cell, width in zip(row, col_widths))
        for row in values
    )
    return header_line + "\n" + body


def generate_schema_prompt_sqlite(db_path, num_rows=None):
    """Build a schema prompt from the CREATE statements of a SQLite database.

    Args:
        db_path: Path to the SQLite database file.
        num_rows: When truthy, each table's DDL is followed by a comment block
            showing its first ``num_rows`` rows, rendered by
            ``nice_look_table``.

    Returns:
        str: One entry per table (DDL, optionally with example rows), joined
        by blank lines, in the order ``sqlite_master`` returns the tables.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table'")
        table_defs = cursor.fetchall()

        schema_parts = []
        for table_name, create_prompt in table_defs:
            # sqlite_sequence is SQLite's internal AUTOINCREMENT bookkeeping
            # table, not part of the user schema.
            if table_name == "sqlite_sequence":
                continue
            if not num_rows:
                schema_parts.append(create_prompt)
                continue

            # Name as shown in the prompt text (same formatting as before).
            cur_table = table_name
            if cur_table in ["order", "by", "group"]:
                cur_table = "`{}`".format(cur_table)

            # Name as executed: always quoted, so reserved words, spaces and
            # embedded quotes in table names cannot break the query.
            quoted_name = '"{}"'.format(table_name.replace('"', '""'))
            cursor.execute(
                "SELECT * FROM {} LIMIT ?".format(quoted_name), (num_rows,)
            )
            column_names = [description[0] for description in cursor.description]
            values = cursor.fetchall()
            rows_prompt = nice_look_table(column_names=column_names, values=values)
            verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(
                num_rows, cur_table, num_rows, rows_prompt
            )
            schema_parts.append("{} \n {}".format(create_prompt, verbose_prompt))
    finally:
        # Release the database handle even if a query fails.
        conn.close()

    return "\n\n".join(schema_parts)

# Basic Utilities
def load_json(dir):
    """Load and return the parsed contents of a JSON file.

    Args:
        dir: Path of the JSON file to read. The parameter name is kept for
            backward compatibility even though it shadows the ``dir`` builtin.

    Returns:
        The decoded JSON value.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(dir, "r", encoding="utf-8") as handle:
        return json.load(handle)

def connect_db(db_path):
    """Open and return a new SQLite connection for ``db_path``."""
    return sqlite3.connect(db_path)

def execute_sql(sql, db_path, return_time=False):
    """Run ``sql`` against the SQLite database at ``db_path``.

    Args:
        sql: SQL statement to execute.
        db_path: Path to the SQLite database file.
        return_time: When true, return the elapsed time instead of the rows.

    Returns:
        The fetched rows (list of tuples), or the elapsed seconds as a float
        when ``return_time`` is true. The timed window covers cursor creation,
        execution, fetching and closing the connection.

    Raises:
        sqlite3.Error: If the statement cannot be executed. The connection is
            closed before the error propagates.
    """
    conn = connect_db(db_path)
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.
    start_time = time.perf_counter()
    try:
        cursor = conn.cursor()
        cursor.execute(sql)
        res = cursor.fetchall()
    finally:
        # Always release the connection, including when the query is invalid.
        conn.close()
    exec_time = time.perf_counter() - start_time

    if return_time:
        return exec_time
    return res

# Calculate exact match
def calculate_exact_match(predicted_res, ground_truth_res):
    """Return 1 when both results hold the same set of rows, otherwise 0.

    Both inputs are compared as sets, so row order and duplicate rows are
    ignored.
    """
    return 1 if set(predicted_res) == set(ground_truth_res) else 0


def calculate_row_match(predicted_row, ground_truth_row):
    """
    Compare one predicted row with one ground-truth row, value by value.

    Membership is tested by value only; column positions are not compared.

    Args:
    predicted_row (tuple): The predicted row values.
    ground_truth_row (tuple): The actual row values from ground truth.

    Returns:
    tuple: Three ratios, each divided by the number of ground-truth columns:
        (predicted values found in the ground truth,
         predicted values absent from the ground truth,
         ground-truth values absent from the prediction).
    """
    hits = sum(1 for value in predicted_row if value in ground_truth_row)
    extra_in_pred = sum(1 for value in predicted_row if value not in ground_truth_row)
    missing_from_pred = sum(
        1 for value in ground_truth_row if value not in predicted_row
    )

    column_count = len(ground_truth_row)
    return (
        hits / column_count,
        extra_in_pred / column_count,
        missing_from_pred / column_count,
    )

# Calculate F1 score
def calculate_f1_score(predicted, ground_truth):
    """
    Compute a soft F1 score between predicted rows and ground-truth rows.

    Both inputs are de-duplicated through ``set`` and turned back into lists.
    Rows are then paired by position in those lists, so the pairing follows
    set iteration order. Each pair is scored with ``calculate_row_match``.
    Rows without a partner count as one full false negative (ground truth
    only) or one full false positive (prediction only).

    Args:
    predicted (iterable of tuples): Predicted results from SQL query.
    ground_truth (iterable of tuples): Actual results expected (ground truth).

    Returns:
    float: The calculated F1 score; 1.0 when both inputs are empty.
    """
    # Two empty results agree perfectly.
    if not (predicted or ground_truth):
        return 1.0

    # Drop duplicate rows.
    pred_rows = list(set(predicted)) if predicted else []
    truth_rows = list(set(ground_truth))

    # Score rows that have a positional partner on the other side.
    paired = [
        calculate_row_match(pred_row, gt_row)
        for pred_row, gt_row in zip(pred_rows, truth_rows)
    ]
    unmatched_truth = max(len(truth_rows) - len(pred_rows), 0)
    unmatched_pred = max(len(pred_rows) - len(truth_rows), 0)

    tp = sum([match for match, _, _ in paired])
    fp = sum([pred_only for _, pred_only, _ in paired] + [1] * unmatched_pred)
    fn = sum([truth_only for _, _, truth_only in paired] + [1] * unmatched_truth)

    predicted_total = tp + fp
    precision = tp / predicted_total if predicted_total > 0 else 0
    actual_total = tp + fn
    recall = tp / actual_total if actual_total > 0 else 0

    denominator = precision + recall
    if denominator > 0:
        return 2 * precision * recall / denominator
    return 0

# Calculate velocity efficiency score
def clean_abnormal(input):
    """Drop samples lying more than three standard deviations from the mean.

    Args:
        input: Sequence of numeric samples. The parameter name is kept for
            backward compatibility even though it shadows the builtin.

    Returns:
        list: The samples within ``[mean - 3 * std, mean + 3 * std]``, in
        their original order. An empty input yields an empty list.
    """
    samples = np.asarray(input)
    if samples.size == 0:
        return []

    mean = np.mean(samples, axis=0)
    std = np.std(samples, axis=0)
    lower = mean - 3 * std
    upper = mean + 3 * std
    # Inclusive bounds: with zero spread (one sample, or all samples equal)
    # strict bounds would reject every sample and leave nothing to average.
    return [x for x in samples if lower <= x <= upper]


def calculate_velocity_efficiency_score(
    predicted_sql, ground_truth, db_path, iterate_num
):
    """Reward a correct predicted query by its speed relative to the ground truth.

    Both queries are executed once to compare results. If the result sets
    differ, the reward is 0. Otherwise both queries are timed ``iterate_num``
    times, and the ratio ``ground_truth_time / predicted_time`` is averaged
    after outlier removal with ``clean_abnormal``.

    Args:
        predicted_sql: Predicted SQL query.
        ground_truth: Ground-truth SQL query.
        db_path: Path to the SQLite database file.
        iterate_num: Number of timing repetitions.

    Returns:
        0 when the results differ or no usable timing sample exists, otherwise
        1.25 (ratio >= 2), 1 (>= 1), 0.75 (>= 0.5), 0.5 (>= 0.25) or 0.25.
    """
    predicted_res = execute_sql(predicted_sql, db_path)
    ground_truth_res = execute_sql(ground_truth, db_path)
    # Speed only counts when the predicted query is correct.
    if set(predicted_res) != set(ground_truth_res):
        return 0

    diff_list = []
    for _ in range(iterate_num):
        predicted_time = execute_sql(predicted_sql, db_path, return_time=True)
        ground_truth_time = execute_sql(ground_truth, db_path, return_time=True)
        # A zero reading (timer resolution) makes the ratio undefined; skip it
        # instead of raising ZeroDivisionError.
        if predicted_time > 0:
            diff_list.append(ground_truth_time / predicted_time)

    # No usable sample (e.g. iterate_num == 0): nothing to average.
    if not diff_list:
        return 0

    processed_diff_list = clean_abnormal(diff_list)
    # If outlier removal discarded everything, fall back to the raw samples.
    if len(processed_diff_list) == 0:
        processed_diff_list = diff_list
    time_ratio = sum(processed_diff_list) / len(processed_diff_list)

    if time_ratio == 0:
        return 0
    if time_ratio >= 2:
        return 1.25
    if time_ratio >= 1:
        return 1
    if time_ratio >= 0.5:
        return 0.75
    if time_ratio >= 0.25:
        return 0.5
    return 0.25

### Load Dataset

In [None]:
import json
import random
import matplotlib.pyplot as plt

from typing import Dict, List, Any, Union

from dotenv import load_dotenv
from openai import AsyncOpenAI

# Load variables from a local .env file into the environment.
# NOTE(review): presumably this supplies the OpenAI API key used later — confirm.
load_dotenv()

# Load the BIRD-bench train split and its table metadata.
with open("train/train.json", "r", encoding="utf-8") as f:
    trainset = json.load(f)

with open("train/train_tables.json", "r", encoding="utf-8") as f:
    train_tables = json.load(f)

# select 'works_cycles' database
works_cycles_trainset = [item for item in trainset if item["db_id"] == "works_cycles"]
works_cycles_train_tables = [
    table for table in train_tables if table["db_id"] == "works_cycles"
]

# save
with open("train/works_cycles_train.json", "w", encoding="utf-8") as f:
    json.dump(works_cycles_trainset, f)
with open("train/works_cycles_train_tables.json", "w", encoding="utf-8") as f:
    json.dump(works_cycles_train_tables, f)

# Reload the filtered subset from the same path it was written to above.
with open("train/works_cycles_train.json", "r", encoding="utf-8") as f:
    dataset = json.load(f)

# Fixed seed so the train/test split is reproducible.
random.seed(42)
random.shuffle(dataset)
trainset = dataset[:100]
testset = dataset[100:200]

# The full schema (CREATE statements) is given to the model with every question.
table_schema = generate_schema_prompt_sqlite("works_cycles/works_cycles.sqlite")

from ape.common.types import DatasetItem

trainset = [
    DatasetItem(
        inputs={"question": item["question"], "table_schema": table_schema},
        outputs={"answer": item["SQL"]},
    )
    for item in trainset
]
testset = [
    DatasetItem(
        inputs={"question": item["question"], "table_schema": table_schema},
        outputs={"answer": item["SQL"]},
    )
    for item in testset
]


## Prepare Prompt to optimize

In [2]:
from ape.common import Prompt

# Base prompt template. {question} and {table_schema} are placeholders that
# match the keys of each dataset item's `inputs`.
prompt = """\
Using valid SQLite, answer the following questions for the tables provided above.
Question: {question}
Table Schema: {table_schema}
Generate the SQLite for the above question after thinking step by step.
"""

# Structured-output specification passed to the model as `response_format`.
# The response must be a JSON object with exactly two string fields, "thought"
# and "answer"; both are required and no additional properties are allowed.
json_schema = {
    "type": "json_schema", 
    "json_schema": {
        "name": "TextToSQL",
        "strict": True,
        "schema": {
            "type": "object",
            "properties": {
                "thought": {
                    "type": "string",
                    "description": "The reasoning process of the problem solving"
                },
                "answer": {
                    "type": "string",
                    "description": "The answer SQL query for question"
                }
            },
            "required": ["thought", "answer"],
            "additionalProperties": False
        }
    }
}


# Initial prompt handed to the trainer: a single system message built from the
# template above, together with the model name, temperature and response format.
student_prompt = Prompt(
    messages=[
        {"role": "system", "content": prompt},
    ],
    model="gpt-4o-mini",
    temperature=0.0,
    response_format=json_schema,
    name="SQL Generator",
)

## Prepare Generator, Metric, and Global Metric

In [3]:
from ape.common.generator import BaseGenerator
from ape.common.metric import BaseMetric
from ape.common.global_metric import BaseGlobalMetric
from ape.common.types import MetricResult, GlobalMetricResult

# define generator, metric, global metric

# SQLite database that predicted and ground-truth queries are executed against.
db_path = "works_cycles/works_cycles.sqlite"
# Shared async OpenAI client used by BirdBenchSolver.generate.
# NOTE: the variable name is the same as the `openai` package name.
openai = AsyncOpenAI()

class BirdBenchSolver(BaseGenerator):
    """Generator that asks an OpenAI chat model for a SQL answer."""

    async def generate(
        self,
        prompt: Prompt,
        inputs: Dict[str, Any],
    ) -> Union[Dict[str, Any], str]:
        """Format the prompt with ``inputs`` and return the parsed JSON reply.

        The request is attempted up to three times. Every failure is printed.
        If all attempts fail, a placeholder with an empty answer is returned.
        """
        request = {
            "messages": prompt.format(**inputs).messages,
            "model": prompt.model,
            "response_format": prompt.response_format,
            "temperature": 0.0,
        }
        for _attempt in range(3):
            try:
                completion = await openai.chat.completions.create(**request)
                return json.loads(completion.choices[0].message.content)
            except Exception as exc:
                print(exc)

        # Every attempt failed: hand back an empty answer.
        return {
            "thought": "error",
            "answer": "",
        }


class BirdBenchMetric(BaseMetric):
    """Score a predicted SQL query by executing it next to the ground truth."""

    async def compute(
        self,
        dataset_item: DatasetItem,
        pred: Dict[str, Any],
    ) -> MetricResult:
        """Return the mean of exact match and soft F1 over the result rows.

        Both queries run against ``db_path``. Any exception (for example,
        invalid predicted SQL) is printed and gives a score of 0.0.
        """
        try:
            predicted_sql = pred["answer"]
            gold_sql = dataset_item["outputs"]["answer"]

            predicted_rows = execute_sql(predicted_sql, db_path)
            gold_rows = execute_sql(gold_sql, db_path)

            exact = calculate_exact_match(predicted_rows, gold_rows)
            f1 = calculate_f1_score(predicted_rows, gold_rows)

            combined = (exact + f1) / 2
            details = {
                "f1_score": f1,
                "exact_match_score": exact,
            }
            return MetricResult(score=combined, trace=details)
        except Exception as exc:
            print(exc)
            failed = {
                "f1_score": 0.0,
                "exact_match_score": 0.0,
            }
            return MetricResult(score=0.0, trace=failed)

class GlobalAverageMetric(BaseGlobalMetric):
    """Average per-item metric results into one global result."""

    async def compute(
        self,
        results: List[MetricResult],
    ) -> GlobalMetricResult:
        """Return the mean score plus mean exact-match and F1 in the trace.

        Any exception gives an all-zero result. This includes the division by
        zero raised while averaging the trace values of an empty ``results``.
        """
        try:
            overall = [item.score for item in results]
            exact_values = [item.trace["exact_match_score"] for item in results]
            f1_values = [item.trace["f1_score"] for item in results]

            mean_score = sum(overall) / len(overall) if len(results) > 0 else 0.0
            averages = {
                "exact_match_scores": sum(exact_values) / len(exact_values),
                "f1_scores": sum(f1_values) / len(f1_values),
            }
            return GlobalMetricResult(score=mean_score, trace=averages)
        except Exception:
            zeros = {
                "exact_match_scores": 0.0,
                "f1_scores": 0.0,
            }
            return GlobalMetricResult(score=0.0, trace=zeros)

## Select Trainer & Run

In [None]:
from ape.core.trainer import (
    TextGradientTrainer,
    ExpelTrainer,
    FewShotTrainer,
    EvoPromptTrainer,
    DspyMiproTrainer,
    OptunaTrainer,
)

# Define the trainer. FewShotTrainer is the one used here; the other trainer
# classes imported above are not used in this cell.
trainer = FewShotTrainer(
    generator=BirdBenchSolver(),
    metric=BirdBenchMetric(),
    global_metric=GlobalAverageMetric(),
    testmode=True # If True, trainer will run prompts for validation set and save results.
)

# Run the trainer on the 100-item train split, with the 100-item test split as
# the validation set. The top-level `await` requires an environment that
# supports it, such as a notebook cell.
optimized_prompt, report = await trainer.train(
    prompt=student_prompt,
    trainset=trainset,
    valset=testset,
)


## Print Optimized Prompt

In [None]:
# Print every message of the optimized prompt returned by trainer.train,
# one message per line.
for message in optimized_prompt.messages:
    print(message)

## Print Benchmark Test Results

In [None]:
# visualize experiment results
def visualize_scores(report):
    """Plot the training and validation score of every iteration.

    Reads ``report.scores``, a sequence of mappings with the keys "score"
    and "val_score". Both curves go on one figure, and every point is
    labelled with its value rounded to two decimals.
    """
    history = report.scores
    train_curve = [entry["score"] for entry in history]
    val_curve = [entry["val_score"] for entry in history]
    steps = range(1, len(train_curve) + 1)

    plt.figure(figsize=(10, 6))
    plt.plot(steps, train_curve, label='Training Set', marker='o')
    plt.plot(steps, val_curve, label='Validation Set', marker='s')

    # Shared styling for the numeric labels placed above each point.
    label_style = {"ha": 'center', "va": 'bottom', "fontsize": 8}
    for step, train_score, val_score in zip(steps, train_curve, val_curve):
        plt.text(step, train_score, f'{train_score:.2f}', color='blue', **label_style)
        plt.text(step, val_score, f'{val_score:.2f}', color='green', **label_style)

    plt.title('Training and Validation Scores over Iterations')
    plt.xlabel('Iteration')
    plt.ylabel('Score')
    plt.legend()
    plt.show()

visualize_scores(report)