In [None]:
from datasets import load_dataset

import re
from os import environ
from dotenv import load_dotenv
from distilabel.models.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration

# Populate the process environment from a local .env file (if one exists) so
# that OPENAI_API_KEY can be read via `environ` in the configuration cell.
load_dotenv()

In [None]:
# Keep only the first 10 rows of the train split so the run stays small.
dataset = (
    load_dataset("gretelai/synthetic_text_to_sql", split="train")
    .select(range(10))
)

In [None]:
# --- LLM endpoint -----------------------------------------------------------
BASE_URL = "https://api.openai.com/v1"
MODEL = "gpt-4.1-nano"
# Evaluates to None when OPENAI_API_KEY is absent from the environment.
API_KEY = environ.get("OPENAI_API_KEY")

# Options forwarded to the LLM as `generation_kwargs`.
GENERATION_KWARGS = dict(temperature=0.2, max_new_tokens=1024)

# --- Dataset columns and batching -------------------------------------------
# Names of the dataset columns that feed the prompt placeholders below.
PROMPT_COLUMN = "sql_prompt"
CONTEXT_COLUMN = "sql_context"
EXPLANATION_COLUMN = "sql_explanation"

INPUT_BATCH_SIZE = 16
NUM_GENERATIONS = 1

# --- Prompt -----------------------------------------------------------------
# The double-brace placeholders (instruction, context, explanation) match the
# `columns` passed to the TextGeneration task. The model is asked to wrap its
# reasoning in <think> tags and the final query in <sql> tags, which is the
# format the extract_* helpers parse.
PROMPT_TEMPLATE = """\
You are an expert in writing optimized SQL queries, with strong logical reasoning skills.

Follow these guidelines:
- Use <think></think> tags to explain your reasoning process step by step.
- Use <sql></sql> tags to present the final SQL query.

Instruction:
{{ instruction }}

Context:
{{ context }}

Explanation:
{{ explanation }}

Expected Output:
<think>
1. Interpret the instruction and context to determine the requirements.
2. Identify relevant tables, columns, filters, and relationships.
3. Formulate a clear and efficient SQL query that meets the requirements.
4. Optimize the query for performance and accuracy.
5. Briefly explain the logic behind the query.
6. Ensure the output matches both instruction and context.
</think>
<sql>
Final SQL query goes here
</sql>""".rstrip()

In [None]:
with Pipeline() as pipeline:
    # LLM client built from the endpoint settings in the configuration cell.
    llm = OpenAILLM(
        model=MODEL,
        base_url=BASE_URL,
        api_key=API_KEY,
        generation_kwargs=GENERATION_KWARGS,
    )

    # Template placeholder name -> dataset column that supplies its value.
    column_for_placeholder = {
        "instruction": PROMPT_COLUMN,
        "context": CONTEXT_COLUMN,
        "explanation": EXPLANATION_COLUMN,
    }

    # Single text-generation step; created inside the `with` block so it
    # belongs to `pipeline`.
    TextGeneration(
        llm=llm,
        template=PROMPT_TEMPLATE,
        columns=list(column_for_placeholder),
        input_mappings=column_for_placeholder,
        input_batch_size=INPUT_BATCH_SIZE,
        num_generations=NUM_GENERATIONS,
        group_generations=True,
    )

In [None]:
def extract_sql(text: str) -> str | None:
    """
    Extracts the SQL using regex from the generated text.
    :param text:
    :return:
    """
    sql_match = re.search(r"<sql>(.*?)</sql>", text, re.DOTALL)
    if sql_match:
        return sql_match.group(1).strip()

def extract_think(text: str) -> str | None:
    """
    Extracts the think using regex from the generated text.
    :param text:
    :return:
    """
    think_match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
    if think_match:
        return think_match.group(1).strip()

In [None]:
# Execute the pipeline over the sampled dataset. The result is indexed below
# as distiset['default']['train'] to read the 'generation' column.
# NOTE(review): presumably this sends requests to the configured API endpoint,
# so re-running the cell may incur cost — confirm before running repeatedly.
distiset = pipeline.run(dataset=dataset)

In [None]:
# ANSI escape sequences used to colour the printed output.
COLORED_GREEN = "\033[92m"
COLORED_BLUE = "\033[94m"
COLORED_RESET = "\033[0m"

# The pipeline step is configured with group_generations=True, so each entry
# of the 'generation' column is expected to be a list of generations for one
# input row; only the first generation of each row is shown.
for generations in distiset['default']['train']['generation']:
    # `or []` covers a missing (None) entry and the default covers an empty
    # list, so one bad row cannot abort the whole loop.
    first_generation = next(iter(generations or []), None)
    if first_generation is None:
        # Nothing to parse for this row; print None placeholders instead.
        think = sql = None
    else:
        think = extract_think(first_generation)
        sql = extract_sql(first_generation)
    print(f"Think: {COLORED_BLUE}{think}{COLORED_RESET}")
    print(f"SQL: {COLORED_GREEN}{sql}{COLORED_RESET}")
    print("-" * 20)