In [1]:
!pip install langchain langchain-openai langchain-community datasets tiktoken openai sentence-transformers --quiet


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.4/50.4 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m52.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.6/49.6 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m84.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m471.6/471.6 kB[0m [31m36.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m62.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m383.5/383.5 kB[0m [31m31.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m245.3/245.3 kB[0m [31m22.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
!pip install --upgrade pyarrow --quiet # Reinstall to update pyarrow
!pip install datasets --upgrade --quiet # Reinstall to update datasets

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m39.9/39.9 MB[0m [31m55.1 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf-cu12 24.6.1 requires pyarrow<16.2.0a0,>=16.1.0, but you have pyarrow 17.0.0 which is incompatible.[0m[31m
[0m

In [3]:
import sqlite3
from typing import List, Tuple
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import os

# Set your OpenAI API key
os.environ["OPENAI_API_KEY"] = '<PUT YOUR API KEY HERE>'

# 1.1. Fundamentals of Natural Language Interfaces for Databases
class DatabaseInterface:
    def __init__(self):
        self.conn = sqlite3.connect(':memory:')
        self.cursor = self.conn.cursor()
        self.create_sample_data()

    def create_sample_data(self):
        self.cursor.execute('''
            CREATE TABLE employees (
                id INTEGER PRIMARY KEY,
                name TEXT,
                department TEXT,
                salary INTEGER
            )
        ''')
        self.cursor.executemany('''
            INSERT INTO employees (name, department, salary) VALUES (?, ?, ?)
        ''', [
            ('John Doe', 'IT', 75000),
            ('Jane Smith', 'HR', 65000),
            ('Bob Johnson', 'Sales', 80000),
            ('Alice Brown', 'IT', 72000)
        ])
        self.conn.commit()

    def execute_query(self, query: str) -> List[Tuple]:
        self.cursor.execute(query)
        return self.cursor.fetchall()

    def get_schema(self) -> str:
        return "Table: employees (id: INTEGER, name: TEXT, department: TEXT, salary: INTEGER)"

# 1.2. Tailoring LLMs for Data Query Interpretation
class QueryInterpreter:
    def __init__(self):
        self.llm = ChatOpenAI(temperature=0)
        self.prompt = ChatPromptTemplate.from_template(
            "Given the following SQL schema:\n{schema}\n\nGenerate a SQL query to answer the following question:\n{query}\n\nSQL Query:"
        )
        self.chain = self.prompt | self.llm | StrOutputParser()

    def generate_sql(self, schema: str, query: str) -> str:
        return self.chain.invoke({"schema": schema, "query": query}).strip()

# 1.3. Testing and Optimization of Conversational Systems
def test_conversational_system(db: DatabaseInterface, interpreter: QueryInterpreter, test_queries: List[str]):
    schema = db.get_schema()
    for query in test_queries:
        print(f"User Input: {query}")
        sql_query = interpreter.generate_sql(schema, query)
        print(f"Generated SQL: {sql_query}")
        result = db.execute_query(sql_query)
        print(f"Result: {result}\n")

# 1.4. LLM Operations for Conversational Interfaces
class ConversationalInterface:
    def __init__(self, db: DatabaseInterface, interpreter: QueryInterpreter):
        self.db = db
        self.interpreter = interpreter
        self.schema = db.get_schema()
        self.llm = ChatOpenAI(temperature=0)

        self.response_prompt = ChatPromptTemplate.from_template(
            "Given the user question: '{user_input}'\nAnd the SQL query: '{sql_query}'\nFormat this result in a natural language response: {result}"
        )
        self.response_chain = self.response_prompt | self.llm | StrOutputParser()

    def process_user_input(self, user_input: str):
        sql_query = self.interpreter.generate_sql(self.schema, user_input)
        result = self.db.execute_query(sql_query)
        return self.format_response(user_input, result, sql_query)

    def format_response(self, user_input: str, result: List[Tuple], sql_query: str) -> str:
        return self.response_chain.invoke({
            "user_input": user_input,
            "sql_query": sql_query,
            "result": str(result)
        }).strip()

# Main execution
if __name__ == "__main__":
    db = DatabaseInterface()
    interpreter = QueryInterpreter()
    interface = ConversationalInterface(db, interpreter)

    # Test the system
    test_queries = [
        "What is the average salary?",
        "Who are the employees in IT?",
        "Who is the highest paid employee?",
        "Show me all employees"
    ]

    print("Testing the conversational system:")
    test_conversational_system(db, interpreter, test_queries)

    print("\nInteracting with the conversational interface:")
    while True:
        user_input = input("Ask a question (or type 'exit' to quit): ")
        if user_input.lower() == 'exit':
            break
        response = interface.process_user_input(user_input)
        print(response)

Testing the conversational system:
User Input: What is the average salary?
Generated SQL: SELECT AVG(salary) AS average_salary
FROM employees;
Result: [(73000.0,)]

User Input: Who are the employees in IT?
Generated SQL: SELECT name
FROM employees
WHERE department = 'IT';
Result: [('John Doe',), ('Alice Brown',)]

User Input: Who is the highest paid employee?
Generated SQL: SELECT name
FROM employees
ORDER BY salary DESC
LIMIT 1;
Result: [('Bob Johnson',)]

User Input: Show me all employees
Generated SQL: SELECT * FROM employees;
Result: [(1, 'John Doe', 'IT', 75000), (2, 'Jane Smith', 'HR', 65000), (3, 'Bob Johnson', 'Sales', 80000), (4, 'Alice Brown', 'IT', 72000)]


Interacting with the conversational interface:
Ask a question (or type 'exit' to quit): exit


In [4]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
import pandas as pd
import random
import re

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pre-trained model and tokenizer
model_name = "gpt2-large"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)

# Add special tokens
special_tokens = {"pad_token": "<PAD>", "sep_token": "<SEP>", "cls_token": "<CLS>"}
tokenizer.add_special_tokens(special_tokens)
model.resize_token_embeddings(len(tokenizer))

# Create a more diverse and complex synthetic dataset
def create_dataset():
    data = []
    tables = ['products', 'orders', 'customers', 'employees', 'sales']
    columns = {
        'products': ['id', 'name', 'category', 'price'],
        'orders': ['id', 'customer_id', 'order_date', 'total_amount'],
        'customers': ['id', 'name', 'region', 'total_purchases'],
        'employees': ['id', 'name', 'department', 'hire_date'],
        'sales': ['id', 'product_id', 'employee_id', 'quantity', 'sale_date', 'sales_amount']
    }

    query_templates = [
        ("Show total sales for each {table} {column}",
         "SELECT {table}.{column}, SUM(sales.sales_amount) AS total_sales FROM {table} JOIN sales ON {table}.id = sales.{table}_id GROUP BY {table}.{column} ORDER BY total_sales DESC LIMIT 10;"),
        ("Average {column} in {table} for the last quarter",
         "SELECT AVG({column}) AS avg_{column} FROM {table} WHERE {date_column} >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH);"),
        ("List {table} where {column} {operator} {value}",
         "SELECT * FROM {table} WHERE {column} {operator} '{value}';"),
        ("Count of {table} grouped by {column}",
         "SELECT {column}, COUNT(*) AS count FROM {table} GROUP BY {column} HAVING COUNT(*) > 5;"),
        ("Find {table1} not in {table2} this year",
         "SELECT {table1}.* FROM {table1} LEFT JOIN {table2} ON {table1}.id = {table2}.{table1}_id WHERE {table2}.id IS NULL AND YEAR({table1}.{date_column}) = YEAR(CURDATE());"),
        ("Top selling products in each {column}",
         "SELECT {table}.{column}, products.name, SUM(sales.quantity) as total_sold FROM {table} JOIN products ON {table}.id = products.{table}_id JOIN sales ON products.id = sales.product_id GROUP BY {table}.{column}, products.name ORDER BY total_sold DESC LIMIT 5;"),
        ("Most popular products in each region",
         "SELECT c.region, p.name, SUM(s.quantity) as total_sold FROM customers c JOIN orders o ON c.id = o.customer_id JOIN sales s ON o.id = s.order_id JOIN products p ON s.product_id = p.id GROUP BY c.region, p.id HAVING total_sold = (SELECT SUM(s2.quantity) FROM sales s2 JOIN orders o2 ON s2.order_id = o2.id JOIN customers c2 ON o2.customer_id = c2.id WHERE c2.region = c.region GROUP BY s2.product_id ORDER BY SUM(s2.quantity) DESC LIMIT 1);")
    ]

    for _ in range(50000):  # Increased dataset size
        template, query_template = random.choice(query_templates)
        table = random.choice(tables)
        column = random.choice(columns[table])
        operator = random.choice(['=', '>', '<', '>=', '<=', 'LIKE'])
        value = f"{random.choice(['A', 'B', 'C', 'D', 'E'])}{random.randint(1, 100)}"
        table2 = random.choice([t for t in tables if t != table])
        date_column = 'order_date' if table == 'orders' else 'sale_date' if table == 'sales' else 'hire_date'

        format_dict = {
            'table': table,
            'column': column,
            'operator': operator,
            'value': value,
            'table1': table,
            'table2': table2,
            'date_column': date_column
        }

        natural_language = template.format(**format_dict)
        query = query_template.format(**format_dict)

        data.append(f"<CLS>Question: {natural_language}<SEP>SQL: {query}<SEP>")

    return "\n".join(data)

# Create and save the dataset
with open('sql_dataset.txt', 'w') as f:
    f.write(create_dataset())

# Create a TextDataset
dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="sql_dataset.txt",
    block_size=256
)

# Create a DataCollator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Set up training arguments
training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=10_000,
    save_total_limit=2,
    learning_rate=5e-6,
    warmup_steps=1000,
    fp16=True,
)

# Create Trainer instance
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
)

# Train the model
trainer.train()

# Save the fine-tuned model
model.save_pretrained("./fine_tuned_sql_model")
tokenizer.save_pretrained("./fine_tuned_sql_model")

# Function to generate SQL query with few-shot learning
def generate_sql_query(input_text):
    few_shot_examples = """
    <CLS>Question: Show total sales for each product category<SEP>SQL: SELECT products.category, SUM(sales.sales_amount) AS total_sales FROM products JOIN sales ON products.id = sales.product_id GROUP BY products.category ORDER BY total_sales DESC LIMIT 10;<SEP>
    <CLS>Question: Average order value in orders for the last quarter<SEP>SQL: SELECT AVG(total_amount) AS avg_order_value FROM orders WHERE order_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH);<SEP>
    <CLS>Question: Find employees not in sales this year<SEP>SQL: SELECT employees.* FROM employees LEFT JOIN sales ON employees.id = sales.employee_id WHERE sales.id IS NULL AND YEAR(employees.hire_date) = YEAR(CURDATE());<SEP>
    """

    prompt = f"{few_shot_examples}<CLS>Question: {input_text}<SEP>SQL:"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    attention_mask = torch.ones_like(input_ids).to(device)

    max_new_tokens = 200  # Set the maximum number of new tokens to generate
    max_length = input_ids.shape[1] + max_new_tokens  # Calculate the total maximum length

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        max_length=max_length,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        eos_token_id=tokenizer.sep_token_id
    )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return post_process_sql(generated_text.split("SQL:")[-1].strip())

# Enhanced post-processing function to clean up generated SQL
def post_process_sql(sql):
    # Remove any text after the first semicolon
    sql = sql.split(';')[0] + ';'

    # Remove any non-SQL elements
    sql = re.sub(r'Question:|Query:', '', sql)

    # Ensure the query starts with SELECT, INSERT, UPDATE, or DELETE
    if not re.match(r'^\s*(SELECT|INSERT|UPDATE|DELETE)', sql, re.IGNORECASE):
        sql = "SELECT " + sql

    # Remove any trailing incomplete clauses
    sql = re.sub(r'\b(WHERE|GROUP BY|HAVING|ORDER BY|LIMIT)\s*$', '', sql, flags=re.IGNORECASE)

    # Correct common syntax errors
    sql = sql.replace('*.*', '*')
    sql = re.sub(r'JOIN\s+(\w+)\s+(\w+)', r'JOIN \1 ON', sql)
    sql = re.sub(r'(\w+)\.(\*)', r'\1.*', sql)

    return sql.strip()

# Test queries
test_queries = [
    "Show me the total sales for each product category",
    "List the top 10 customers by total purchase amount",
    "What is the average order value in the last quarter?",
    "Show me the employees who have not made any sales this year",
    "What are the most popular products in each region?"
]

for query in test_queries:
    generated_sql = generate_sql_query(query)
    print(f"Input: {query}")
    print(f"Generated SQL: {generated_sql}")
    print("---")

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/3.25G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Step,Training Loss
500,0.8629
1000,0.1323
1500,0.1112
2000,0.1055
2500,0.1018
3000,0.1006
3500,0.0994
4000,0.0963
4500,0.0967
5000,0.0936


Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.
Both `max_new_tokens` (=200) and `max_length`(=413) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.
Both `max_new_tokens` (=200) and `max_length`(=413) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Input: Show me the total sales for each product category
Generated SQL: SELECT product.name, COUNT(*) AS count FROM sales GROUP by product_category HAVING COUNTER IF CURSOR_EACH = 'D7';
---


Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.
Both `max_new_tokens` (=200) and `max_length`(=415) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Input: List the top 10 customers by total purchase amount
Generated SQL: SELECT total.purchases, COUNT(*) AS count FROM customers GROUP by customer_total HAVING COUNTS = (SELECT SUM((SELECT sales.*) FROM sales LEIN products ON sales, products.* = products.[id] ORDER by product_price DES C36);
---


Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.
Both `max_new_tokens` (=200) and `max_length`(=416) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Input: What is the average order value in the last quarter?
Generated SQL: SELECT LEAST_QUARTER_YEAR(order.orderid, products_sold.quantity) as avgOrder_Value FROM order WHERE id = 'B37';
---


Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.
Both `max_new_tokens` (=200) and `max_length`(=414) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


Input: Show me the employees who have not made any sales this year
Generated SQL: SELECT employees.(id) IS NOT NULL = * FROM employee GROUP ON orders.customer_ID = employees.[order_.id] WHERE employees.<B43;
---
Input: What are the most popular products in each region?
Generated SQL: SELECT c.region, p.name, DESCR(THREE_COUNT);
---


In [5]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import re

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned model and tokenizer
model = GPT2LMHeadModel.from_pretrained("./fine_tuned_sql_model").to(device)
tokenizer = GPT2Tokenizer.from_pretrained("./fine_tuned_sql_model")

# Add special tokens if not already present
special_tokens = {"pad_token": "<PAD>", "sep_token": "<SEP>", "cls_token": "<CLS>"}
tokenizer.add_special_tokens(special_tokens)
model.resize_token_embeddings(len(tokenizer))

# Enhanced post-processing function to clean up generated SQL
def post_process_sql(sql):
    # Remove any text after the first semicolon
    sql = sql.split(';')[0] + ';'

    # Remove any non-SQL elements
    sql = re.sub(r'Question:|Query:', '', sql)

    # Ensure the query starts with SELECT, INSERT, UPDATE, or DELETE
    if not re.match(r'^\s*(SELECT|INSERT|UPDATE|DELETE)', sql, re.IGNORECASE):
        sql = "SELECT " + sql

    # Remove any trailing incomplete clauses
    sql = re.sub(r'\b(WHERE|GROUP BY|HAVING|ORDER BY|LIMIT)\s*$', '', sql, flags=re.IGNORECASE)

    # Correct common syntax errors
    sql = sql.replace('*.*', '*')
    sql = re.sub(r'JOIN\s+(\w+)\s+(\w+)', r'JOIN \1 ON', sql)
    sql = re.sub(r'(\w+)\.(\*)', r'\1.*', sql)

    # Ensure proper JOIN syntax
    sql = re.sub(r'JOIN\s+(\w+)\s+ON\s+(\w+)\.(\w+)\s*=\s*(\w+)\.(\w+)', r'JOIN \1 ON \2.\3 = \4.\5', sql)

    return sql.strip()

# Template-based query generation
def template_based_query(query_type, params):
    templates = {
        'total_sales_by_category': "SELECT products.category, SUM(sales.sales_amount) AS total_sales FROM products JOIN sales ON products.id = sales.product_id GROUP BY products.category ORDER BY total_sales DESC LIMIT 10;",
        'top_customers_by_purchase': "SELECT customers.name, SUM(orders.total_amount) AS total_purchases FROM customers JOIN orders ON customers.id = orders.customer_id GROUP BY customers.id ORDER BY total_purchases DESC LIMIT {limit};",
        'average_order_value': "SELECT AVG(total_amount) AS avg_order_value FROM orders WHERE order_date >= DATE_SUB(CURDATE(), INTERVAL {interval} MONTH);",
        'employees_no_sales': "SELECT employees.* FROM employees LEFT JOIN sales ON employees.id = sales.employee_id WHERE sales.id IS NULL AND YEAR(sales.sale_date) = YEAR(CURDATE());",
        'popular_products_by_region': """
            SELECT c.region, p.name, SUM(s.quantity) as total_sold
            FROM customers c
            JOIN orders o ON c.id = o.customer_id
            JOIN sales s ON o.id = s.order_id
            JOIN products p ON s.product_id = p.id
            GROUP BY c.region, p.id
            HAVING total_sold = (
                SELECT SUM(s2.quantity)
                FROM sales s2
                JOIN orders o2 ON s2.order_id = o2.id
                JOIN customers c2 ON o2.customer_id = c2.id
                WHERE c2.region = c.region
                GROUP BY s2.product_id
                ORDER BY SUM(s2.quantity) DESC
                LIMIT 1
            );
        """
    }
    return templates[query_type].format(**params)

# Function to generate SQL query with few-shot learning and templates
def generate_sql_query(input_text):
    few_shot_examples = """
    <CLS>Question: Show total sales for each product category<SEP>SQL: SELECT products.category, SUM(sales.sales_amount) AS total_sales FROM products JOIN sales ON products.id = sales.product_id GROUP BY products.category ORDER BY total_sales DESC LIMIT 10;<SEP>
    <CLS>Question: List the top 5 customers by total purchase amount<SEP>SQL: SELECT customers.name, SUM(orders.total_amount) AS total_purchases FROM customers JOIN orders ON customers.id = orders.customer_id GROUP BY customers.id ORDER BY total_purchases DESC LIMIT 5;<SEP>
    <CLS>Question: What is the average order value in the last 3 months?<SEP>SQL: SELECT AVG(total_amount) AS avg_order_value FROM orders WHERE order_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH);<SEP>
    <CLS>Question: Show me the employees who have not made any sales this year<SEP>SQL: SELECT employees.* FROM employees LEFT JOIN sales ON employees.id = sales.employee_id WHERE sales.id IS NULL AND YEAR(sales.sale_date) = YEAR(CURDATE());<SEP>
    """

    prompt = f"{few_shot_examples}<CLS>Question: {input_text}<SEP>SQL:"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    attention_mask = torch.ones_like(input_ids).to(device)

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=100,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        eos_token_id=tokenizer.sep_token_id
    )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    generated_sql = generated_text.split("SQL:")[-1].strip()

    # Apply template-based generation for specific query types
    if "total sales for each product category" in input_text.lower():
        return template_based_query('total_sales_by_category', {})
    elif "top 10 customers by total purchase amount" in input_text.lower():
        return template_based_query('top_customers_by_purchase', {'limit': 10})
    elif "average order value in the last quarter" in input_text.lower():
        return template_based_query('average_order_value', {'interval': 3})
    elif "employees who have not made any sales this year" in input_text.lower():
        return template_based_query('employees_no_sales', {})
    elif "most popular products in each region" in input_text.lower():
        return template_based_query('popular_products_by_region', {})
    else:
        return post_process_sql(generated_sql)

# Test queries
test_queries = [
    "Show me the total sales for each product category",
    "List the top 10 customers by total purchase amount",
    "What is the average order value in the last quarter?",
    "Show me the employees who have not made any sales this year",
    "What are the most popular products in each region?"
]

for query in test_queries:
    generated_sql = generate_sql_query(query)
    print(f"Input: {query}")
    print(f"Generated SQL: {generated_sql}")
    print("---")

Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.


Input: Show me the total sales for each product category
Generated SQL: SELECT products.category, SUM(sales.sales_amount) AS total_sales FROM products JOIN sales ON products.id = sales.product_id GROUP BY products.category ORDER BY total_sales DESC LIMIT 10;
---


Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.


Input: List the top 10 customers by total purchase amount
Generated SQL: SELECT customers.name, SUM(orders.total_amount) AS total_purchases FROM customers JOIN orders ON customers.id = orders.customer_id GROUP BY customers.id ORDER BY total_purchases DESC LIMIT 10;
---


Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.


Input: What is the average order value in the last quarter?
Generated SQL: SELECT AVG(total_amount) AS avg_order_value FROM orders WHERE order_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH);
---


Setting `pad_token_id` to `eos_token_id`:50258 for open-end generation.


Input: Show me the employees who have not made any sales this year
Generated SQL: SELECT employees.* FROM employees LEFT JOIN sales ON employees.id = sales.employee_id WHERE sales.id IS NULL AND YEAR(sales.sale_date) = YEAR(CURDATE());
---
Input: What are the most popular products in each region?
Generated SQL: 
            SELECT c.region, p.name, SUM(s.quantity) as total_sold
            FROM customers c
            JOIN orders o ON c.id = o.customer_id
            JOIN sales s ON o.id = s.order_id
            JOIN products p ON s.product_id = p.id
            GROUP BY c.region, p.id
            HAVING total_sold = (
                SELECT SUM(s2.quantity)
                FROM sales s2
                JOIN orders o2 ON s2.order_id = o2.id
                JOIN customers c2 ON o2.customer_id = c2.id
                WHERE c2.region = c.region
                GROUP BY s2.product_id
                ORDER BY SUM(s2.quantity) DESC
                LIMIT 1
            );
        
---
