Create a an LLM pipeline that will transform any free text query into a SQL query, the key points of this task are:

* Create a valid representation of SQL tables allowing for semantic search that will match the top results to the given free text query

* Based on the table representation the LLM has to create a real SQL query, based on free text user query, that will allow for immediate usage

* LLM has to support creation of queries with different levels of complexity not only the simplest ones

* LLM has to support creating queries to fetch data from different database schemas

* When an error in LLM created sql query is encountered it should attempt to self correct

In [1]:
from langchain.llms import LlamaCpp
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
import re
import pandas as pd
import torch
import numpy as np

In [2]:
llm = LlamaCpp(model_path="sqlcoder2-GGUF\sqlcoder2.Q8_0.gguf",
               n_batch=512,
               n_ctx=2048,
               n_gpu_layers=30,
               verbose=True)

llama_model_loader: loaded meta data with 19 key-value pairs and 485 tensors from sqlcoder2-GGUF\sqlcoder2.Q8_0.gguf (version GGUF V2)
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = starcoder
llama_model_loader: - kv   1:                               general.name str              = StarCoder
llama_model_loader: - kv   2:                   starcoder.context_length u32              = 8192
llama_model_loader: - kv   3:                 starcoder.embedding_length u32              = 6144
llama_model_loader: - kv   4:              starcoder.feed_forward_length u32              = 24576
llama_model_loader: - kv   5:                      starcoder.block_count u32              = 40
llama_model_loader: - kv   6:             starcoder.attention.head_count u32              = 48
llama_model_loader: - kv   7:          starcoder.attention.head_count_kv u32     

In [9]:
ddl_statements = """
-- Create schemas
CREATE SCHEMA public;
CREATE SCHEMA sales;
CREATE SCHEMA analytics;

-- Create tables in the public schema
CREATE TABLE public.employees (
    employee_id INTEGER PRIMARY KEY,
    first_name VARCHAR,
    last_name VARCHAR,
    age INTEGER,
    department_id INTEGER,
    hire_date DATE,
    FOREIGN KEY (department_id) REFERENCES public.departments(department_id)
);

CREATE TABLE public.departments (
    department_id INTEGER PRIMARY KEY,
    department_name VARCHAR,
    manager_id INTEGER,
    FOREIGN KEY (manager_id) REFERENCES public.employees(employee_id)
);

CREATE TABLE public.salaries (
    employee_id INTEGER PRIMARY KEY,
    salary_amount DECIMAL,
    effective_date DATE,
    FOREIGN KEY (employee_id) REFERENCES public.employees(employee_id)
);

-- Create tables in the sales schema
CREATE TABLE sales.orders (
    order_id INTEGER PRIMARY KEY,
    order_date DATE,
    customer_id INTEGER,
    sales_rep_id INTEGER,
    FOREIGN KEY (customer_id) REFERENCES sales.customers(customer_id),
    FOREIGN KEY (sales_rep_id) REFERENCES sales.sales_reps(sales_rep_id)
);

CREATE TABLE sales.customers (
    customer_id INTEGER PRIMARY KEY,
    customer_name VARCHAR,
    contact_number VARCHAR
);

CREATE TABLE sales.sales_reps (
    sales_rep_id INTEGER PRIMARY KEY,
    first_name VARCHAR,
    last_name VARCHAR,
    region VARCHAR
);

CREATE TABLE sales.products (
    product_id INTEGER PRIMARY KEY,
    product_name VARCHAR,
    price DECIMAL
);

-- Create tables in the analytics schema
CREATE TABLE analytics.sales_reports (
    report_id INTEGER PRIMARY KEY,
    report_date DATE,
    total_sales DECIMAL,
    region VARCHAR
);

CREATE TABLE analytics.customer_metrics (
    metric_id INTEGER PRIMARY KEY,
    customer_id INTEGER,
    lifetime_value DECIMAL,
    average_order_value DECIMAL,
    FOREIGN KEY (customer_id) REFERENCES sales.customers(customer_id)
);

CREATE TABLE analytics.product_performance (
    performance_id INTEGER PRIMARY KEY,
    product_id INTEGER,
    sales_quantity INTEGER,
    revenue_generated DECIMAL,
    FOREIGN KEY (product_id) REFERENCES sales.products(product_id)
);
"""

In [10]:
template = """
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{user_question}`
{instructions}

DDL statements:
{ddl_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{user_question}`:
```sql
"""

In [28]:
verification_template = """
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Verify if this SQL query correctly answers the question: {user_question}.
SQL query: {sql_query}
If yes, return the same query. If not return corrected query.

DDL statements:
{ddl_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{user_question}`:
```sql
"""

In [11]:
user_question = "Select the name and age of employees who joined after January 2020"
instructions = """
If the question does not match the existing tables, return 'The query does not match any existing tables. Please check the table names or columns and try again.'
When an error in your sql query is encountered, attempt to correct it.
"""

In [12]:
prompt = PromptTemplate(template=template, input_variables=["user_question", "instructions", "ddl_statements"])

In [29]:
verification_prompt = PromptTemplate(template=verification_template, input_variables=["user_question", "sql_query", "ddl_statements"])

In [19]:
def get_sql_query_from_llm(prompt, llm, user_question, instructions, ddl_statements):
    llm_chain = prompt | llm
    raw_llm_answer = llm_chain.invoke({"user_question": user_question, "instructions": instructions, "ddl_statements": ddl_statements})
    return raw_llm_answer
    

In [30]:
# Function to verify and correct SQL queries
def verify_and_correct_sql(sql_query, user_question, ddl_statements):
    llm_chain = verification_prompt | llm
    raw_llm_answer = llm_chain.invoke({"user_question": user_question, "sql_query": sql_query, "ddl_statements": ddl_statements})
    return raw_llm_answer

In [31]:
sql_query = get_sql_query_from_llm(prompt, llm, user_question, instructions, ddl_statements)
print(sql_query)

Llama.generate: prefix-match hit

llama_print_timings:        load time =   25295.80 ms
llama_print_timings:      sample time =      37.15 ms /   256 runs   (    0.15 ms per token,  6890.61 tokens per second)
llama_print_timings: prompt eval time =   39685.18 ms /   717 tokens (   55.35 ms per token,    18.07 tokens per second)
llama_print_timings:        eval time =   90874.47 ms /   255 runs   (  356.37 ms per token,     2.81 tokens per second)
llama_print_timings:       total time =  131057.76 ms /   972 tokens


SELECT e.first_name, e.last_name, e.age FROM analytics.customer_metrics cm JOIN sales.customers sc ON cm.customer_id = sc.customer_id JOIN sales.orders so ON cm.customer_id = so.customer_id JOIN sales.sales_reps sr ON so.sales_rep_id = sr.sales_rep_id JOIN sales.products sp ON so.product_id::integer = sp.product_id::integer JOIN analytics.product_performance pp ON sp.product_id::integer = pp.performance_id JOIN sales.departments sd ON cm.customer_id::integer::TEXT || sd.department_id::text::VARCHAR::text::integer::text::date::to_char::TEXT::timestamp::epoch::int::TEXT::timestamp::day::to_char::TEXT::timestamp::day::int::varchar::to_char::TEXT::timestamp::month::to_char::TEXT::timestamp::year::int::varchar::to_char::TEXT::timestamp::year::int::varchar::to_char::TEXT::timestamp::year::int::varchar::to_char::TEXT::timestamp::year::int::varchar::to


In [32]:
corrected_sql_query = verify_and_correct_sql(sql_query, user_question, ddl_statements)
print(corrected_sql_query)

Llama.generate: prefix-match hit

llama_print_timings:        load time =   25295.80 ms
llama_print_timings:      sample time =      38.63 ms /   256 runs   (    0.15 ms per token,  6626.80 tokens per second)
llama_print_timings: prompt eval time =   53464.88 ms /   920 tokens (   58.11 ms per token,    17.21 tokens per second)
llama_print_timings:        eval time =  103549.55 ms /   255 runs   (  406.08 ms per token,     2.46 tokens per second)
llama_print_timings:       total time =  157357.06 ms /  1175 tokens


SELECT e.first_name, e.last_name, e.age FROM sales.customers sc JOIN sales.orders so ON sc.customer_id::integer::text::varchar::to_char::TEXT::timestamp::epoch::int::TEXT::timestamp::day::to_char::TEXT::timestamp::year::int::varchar::to_char::TEXT::timestamp::year::int::varchar::to_char::TEXT::timestamp::month::to_char::TEXT::timestamp::year::int::varchar::to_char::TEXT::timestamp::year::int::varchar::to_char::TEXT::timestamp::year::int::varchar::to_char::TEXT::timestamp::year::int::varchar::to_char::TEXT::timestamp::year::int::varchar::to_char::TEXT::timestamp::day::int::varchar::to_char::TEXT::timestamp::month::to_char::TEXT::timestamp::year::int::varchar::to_char::TEXT::timestamp::year::int::varchar::to_char::TEXT::timestamp::year::int::varchar::to_char::TEXT::timestamp::year::int::varchar::to_char::TEXT::timestamp::year::int


In [None]:
def generate_sql_query(natural_language_query):
    try:
        sql_query = llm.generate(natural_language_query)
        return sql_query
    except Exception as e:
        print(f"Error generating SQL query: {e}")
        return None

def sql_pipeline(natural_language_query):
    sql_query = generate_sql_query(natural_language_query)
    return sql_query

# Example usage
nl_query = "Select the name and age of employees who joined after January 2020"
sql_query = sql_pipeline(nl_query)
print("Generated SQL Query:", sql_query)