Create a an LLM pipeline that will transform any free text query into a SQL query, the key points of this task are:

* Create a valid representation of SQL tables allowing for semantic search that will match the top results to the given free text query

* Based on the table representation the LLM has to create a real SQL query, based on free text user query, that will allow for immediate usage

* LLM has to support creation of queries with different levels of complexity not only the simplest ones

* LLM has to support creating queries to fetch data from different database schemas

* When an error in LLM created sql query is encountered it should attempt to self correct

In [20]:
from langchain.llms import LlamaCpp
from langchain.prompts import PromptTemplate
import sqlparse
import subprocess
import csv

In [21]:
llm = LlamaCpp(model_path="sqlcoder2-GGUF\sqlcoder2.Q8_0.gguf",
               n_batch=512,
               n_ctx=2048,
               n_gpu_layers=30,
               verbose=True)

llama_model_loader: loaded meta data with 19 key-value pairs and 485 tensors from sqlcoder2-GGUF\sqlcoder2.Q8_0.gguf (version GGUF V2)
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = starcoder
llama_model_loader: - kv   1:                               general.name str              = StarCoder
llama_model_loader: - kv   2:                   starcoder.context_length u32              = 8192
llama_model_loader: - kv   3:                 starcoder.embedding_length u32              = 6144
llama_model_loader: - kv   4:              starcoder.feed_forward_length u32              = 24576
llama_model_loader: - kv   5:                      starcoder.block_count u32              = 40
llama_model_loader: - kv   6:             starcoder.attention.head_count u32              = 48
llama_model_loader: - kv   7:          starcoder.attention.head_count_kv u32     

llama_model_loader: - kv  11:                      tokenizer.ggml.tokens arr[str,49152]   = ["<|endoftext|>", "<fim_prefix>", "<f...
llama_model_loader: - kv  12:                      tokenizer.ggml.scores arr[f32,49152]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  13:                  tokenizer.ggml.token_type arr[i32,49152]   = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  14:                      tokenizer.ggml.merges arr[str,48891]   = ["Ġ Ġ", "ĠĠ ĠĠ", "ĠĠĠĠ ĠĠ...
llama_model_loader: - kv  15:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  16:                tokenizer.ggml.eos_token_id u32              = 0
llama_model_loader: - kv  17:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  18:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:  322 tensors
llama_model_loader: - type q8_0:  163 tensors
llm_load_voc

In [23]:
generation_template = """
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{user_question}`
If the question does not match any existing tables or columns, return the word 'error' without generating a SQL query.

DDL statements:
{ddl_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{user_question}`:
"""

In [24]:
verification_template = """
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Verify if this SQL query correctly answers the question: {user_question}.
SQL query: {sql_query}
If yes, return the same query. If not return corrected query.

DDL statements:
{ddl_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{user_question}`:
```sql
"""

In [25]:
def is_sql_query_valid(sql_query):
    """
    Checks if the SQL query is valid based on the content of the query.
    Returns True if valid, False if an error is detected.
    """
    if "error" in sql_query.lower():
        return False

    return True

In [26]:
generation_prompt = PromptTemplate(template=generation_template, input_variables=["user_question", "ddl_statements"])

In [27]:
verification_prompt = PromptTemplate(template=verification_template, input_variables=["user_question", "sql_query", "ddl_statements"])

In [28]:
def read_ddl_from_file(file_path):
    with open(file_path, 'r') as file:
        ddl_statements = file.read()
    return ddl_statements

In [29]:
def get_sql_query_from_llm(prompt, llm, user_question, ddl_statements):
    llm_chain = prompt | llm
    raw_llm_answer = llm_chain.invoke({"user_question": user_question, "ddl_statements": ddl_statements})
    return raw_llm_answer.strip()
    

In [30]:
def is_query_correct(sql_query):
    """
    Validate and fix a SQL query using sqlfluff. Return True if the query is correct after fixing,
    otherwise return False.
    """
    # Print the original SQL query (for debugging purposes)
    print("Original SQL Query:")
    print(sql_query)
    print("\n")

    # Write the SQL query to a temporary file
    with open("temp_query.sql", "w") as f:
        f.write(sql_query)
    
    # Run sqlfluff fix on the temporary file
    try:
        fix_result = subprocess.run(
            ["sqlfluff", "fix", "temp_query.sql", "--dialect", "ansi"],
            capture_output=True,
            text=True
        )
        
        # Print the fixed SQL query (for debugging purposes)
        with open("temp_query.sql", "r") as f:
            fixed_sql_query = f.read()

        print("Fixed SQL Query:")
        print(fixed_sql_query)
        print("\n")
        
        # Run sqlfluff lint on the fixed file
        lint_result = subprocess.run(
            ["sqlfluff", "lint", "temp_query.sql", "--dialect", "ansi"],
            capture_output=True,
            text=True
        )
        
        if lint_result.returncode == 0:
            # If lint passes, return True
            return True
        else:
            # If lint fails, check for errors and return False
            output_lines = lint_result.stderr.splitlines()
            errors = [line for line in output_lines if "error" in line.lower()]
            
            if errors:
                print("Errors found in the SQL query:")
                print("\n".join(errors))
            
            return False
    
    except Exception as e:
        print(f"Error running sqlfluff: {e}")
        return False

In [31]:
def verify_and_correct_sql(sql_query, user_question, ddl_statements):
    if not is_query_correct(sql_query):
        llm_chain = verification_prompt | llm
        raw_llm_answer = llm_chain.invoke({"user_question": user_question, "sql_query": sql_query, "ddl_statements": ddl_statements})
        return raw_llm_answer.strip()
    else:
        return sql_query

In [32]:
def text_to_sql_pipe(user_question, ddl_file_path="database.sql"):
    # Read the DDL statements from the .sql file
    ddl_statements = read_ddl_from_file(ddl_file_path)
    
    # Generate the initial SQL query
    sql_query = get_sql_query_from_llm(generation_prompt, llm, user_question, ddl_statements)
    print(f"Generated SQL Query:\n{sql_query}\n")
    
    if is_sql_query_valid(sql_query):
    
        # Verify and potentially correct the SQL query
        final_sql_query = verify_and_correct_sql(sql_query, user_question, ddl_statements)
        print(f"Final SQL Query After Verification and Correction:\n{final_sql_query}\n")
        
        return final_sql_query
    
    else:
        return "The query does not match any existing tables. Please check the table names or columns and try again."

In [33]:
def save_queries_to_csv(queries, filename):
    """
    Save the text queries and their corresponding SQL results to a CSV file.
    """
    with open(filename, mode='w', newline='') as file:
        writer = csv.writer(file, delimiter=';')
        writer.writerow(["Text Query", "Generated SQL Query"])
        
        for text_query in queries:
            # Generate SQL query from text query
            generated_sql = text_to_sql_pipe(text_query)
            
            # Write the text query and the generated SQL query to the CSV
            writer.writerow([text_query, generated_sql])
            print(f"Saved query: {text_query}\nGenerated SQL:\n{generated_sql}\n")

In [None]:
queries = [
    "Show me the names of employees and their departments.",
    "What is the highest salary?",
    "List all orders that were made in the last 30 days.",
    "What are the total sales grouped by region?",
    "List employees whose salary is greater than $50,000.",
    "Get the names of customers and their lifetime value.",
    "Show products where the sales quantity is above 100.",
    "Which departments are managed by the employee with ID 5?",
    "List all sales representatives working in the 'West' region.",
    "What is the most recent sales report date?",
    "Find the average salary of employees in each department.",
    "List customers who have placed more than 5 orders.",
    "Find the total revenue generated by each product.",
    "Show the department with the highest number of employees.",
    "Get the total number of sales representatives in each region.",
    "Find the average lifetime value of customers grouped by region.",
    "List the top 3 products with the highest revenue.",
    "Find employees who have not received a salary in the last year.",
    "Show the total number of orders per customer.",
    "List all departments with more than one manager."
]


In [34]:
queries = [
    "How many stars are in milky way galaxy?",
    "Show me the names of employees and their departments."
]

# ToDo: Add verification if sql query matches the sql tables

In [35]:
csv_filename = "sql_queries_results_v2.csv"
save_queries_to_csv(queries, csv_filename)


llama_print_timings:        load time =   38047.77 ms
llama_print_timings:      sample time =      12.41 ms /    23 runs   (    0.54 ms per token,  1852.75 tokens per second)
llama_print_timings: prompt eval time =   71463.93 ms /   680 tokens (  105.09 ms per token,     9.52 tokens per second)
llama_print_timings:        eval time =  359145.51 ms /    22 runs   (16324.80 ms per token,     0.06 tokens per second)
llama_print_timings:       total time =  433575.90 ms /   702 tokens


Generated SQL Query:
SELECT COUNT(performance_id::integer) AS number_of_stars FROM analytics.product_performance;

Original SQL Query:
SELECT COUNT(performance_id::integer) AS number_of_stars FROM analytics.product_performance;




KeyboardInterrupt: 