Create a an LLM pipeline that will transform any free text query into a SQL query, the key points of this task are:

* Create a valid representation of SQL tables allowing for semantic search that will match the top results to the given free text query

* Based on the table representation the LLM has to create a real SQL query, based on free text user query, that will allow for immediate usage

* LLM has to support creation of queries with different levels of complexity not only the simplest ones

* LLM has to support creating queries to fetch data from different database schemas

* When an error in LLM created sql query is encountered it should attempt to self correct

In [1]:
from langchain.llms import LlamaCpp
from langchain.prompts import PromptTemplate
import sqlparse
import subprocess
import csv

In [3]:
llm = LlamaCpp(model_path="..\sqlcoder2-GGUF\sqlcoder2.Q8_0.gguf",
               n_batch=512,
               n_ctx=2048,
               n_gpu_layers=30,
               verbose=True)

llama_model_loader: loaded meta data with 19 key-value pairs and 485 tensors from ..\sqlcoder2-GGUF\sqlcoder2.Q8_0.gguf (version GGUF V2)
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = starcoder
llama_model_loader: - kv   1:                               general.name str              = StarCoder
llama_model_loader: - kv   2:                   starcoder.context_length u32              = 8192
llama_model_loader: - kv   3:                 starcoder.embedding_length u32              = 6144
llama_model_loader: - kv   4:              starcoder.feed_forward_length u32              = 24576
llama_model_loader: - kv   5:                      starcoder.block_count u32              = 40
llama_model_loader: - kv   6:             starcoder.attention.head_count u32              = 48
llama_model_loader: - kv   7:          starcoder.attention.head_count_kv u32  

In [4]:
generation_template = """
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{user_question}`
If the question does not match any existing tables or columns, return 'error' without generating a SQL query.

DDL statements:
{ddl_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{user_question}`:
"""

In [5]:
verification_template = """
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Verify if this SQL query correctly answers the question: {user_question}.
SQL query: {sql_query}
If yes, return the same query. If not return corrected query.

DDL statements:
{ddl_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{user_question}`:
```sql
"""

In [6]:
def is_sql_query_valid(sql_query):
    """
    Checks if the SQL query is valid based on the content of the query.
    Returns True if valid, False if an error is detected.
    """
    if "error" in sql_query.lower():
        return False

    return True

In [7]:
generation_prompt = PromptTemplate(template=generation_template, input_variables=["user_question", "ddl_statements"])

In [8]:
verification_prompt = PromptTemplate(template=verification_template, input_variables=["user_question", "sql_query", "ddl_statements"])

In [9]:
def read_ddl_from_file(file_path):
    with open(file_path, 'r') as file:
        ddl_statements = file.read()
    return ddl_statements

In [10]:
def get_sql_query_from_llm(prompt, llm, user_question, ddl_statements):
    llm_chain = prompt | llm
    raw_llm_answer = llm_chain.invoke({"user_question": user_question, "ddl_statements": ddl_statements})
    return raw_llm_answer.strip()
    

In [11]:
def is_query_correct(sql_query):
    """
    Validate and fix a SQL query using sqlfluff. Return True if the query is correct after fixing,
    otherwise return False.
    """
    # Print the original SQL query (for debugging purposes)
    print("Original SQL Query:")
    print(sql_query)
    print("\n")

    # Write the SQL query to a temporary file
    with open("temp_query.sql", "w") as f:
        f.write(sql_query)
    
    # Run sqlfluff fix on the temporary file
    try:
        fix_result = subprocess.run(
            ["sqlfluff", "fix", "temp_query.sql", "--dialect", "ansi"],
            capture_output=True,
            text=True
        )
        
        # Print the fixed SQL query (for debugging purposes)
        with open("temp_query.sql", "r") as f:
            fixed_sql_query = f.read()

        print("Fixed SQL Query:")
        print(fixed_sql_query)
        print("\n")
        
        # Run sqlfluff lint on the fixed file
        lint_result = subprocess.run(
            ["sqlfluff", "lint", "temp_query.sql", "--dialect", "ansi"],
            capture_output=True,
            text=True
        )
        
        if lint_result.returncode == 0:
            # If lint passes, return True
            return True
        else:
            # If lint fails, check for errors and return False
            output_lines = lint_result.stderr.splitlines()
            errors = [line for line in output_lines if "error" in line.lower()]
            
            if errors:
                print("Errors found in the SQL query:")
                print("\n".join(errors))
            
            return False
    
    except Exception as e:
        print(f"Error running sqlfluff: {e}")
        return False

In [12]:
def verify_and_correct_sql(sql_query, user_question, ddl_statements):
    if not is_query_correct(sql_query):
        llm_chain = verification_prompt | llm
        raw_llm_answer = llm_chain.invoke({"user_question": user_question, "sql_query": sql_query, "ddl_statements": ddl_statements})
        return raw_llm_answer.strip()
    else:
        return sql_query

In [17]:
def text_to_sql_pipe(user_question, ddl_file_path="../database.sql"):
    # Read the DDL statements from the .sql file
    ddl_statements = read_ddl_from_file(ddl_file_path)
    
    # Generate the initial SQL query
    sql_query = get_sql_query_from_llm(generation_prompt, llm, user_question, ddl_statements)
    print(f"Generated SQL Query:\n{sql_query}\n")
    
    if is_sql_query_valid(sql_query):
    
        # Verify and potentially correct the SQL query
        final_sql_query = verify_and_correct_sql(sql_query, user_question, ddl_statements)
        print(f"Final SQL Query After Verification and Correction:\n{final_sql_query}\n")
        
        return final_sql_query
    
    else:
        return "The query does not match any existing tables. Please check the table names or columns and try again."

In [14]:
def save_queries_to_csv(queries, filename):
    """
    Save the text queries and their corresponding SQL results to a CSV file.
    """
    with open(filename, mode='w', newline='') as file:
        writer = csv.writer(file, delimiter=';')
        writer.writerow(["Text Query", "Generated SQL Query"])
        
        for text_query in queries:
            # Generate SQL query from text query
            generated_sql = text_to_sql_pipe(text_query)
            
            # Write the text query and the generated SQL query to the CSV
            writer.writerow([text_query, generated_sql])
            print(f"Saved query: {text_query}\nGenerated SQL:\n{generated_sql}\n")

In [None]:
queries = [
    "Show me the names of employees and their departments.",
    "What is the highest salary?",
    "List all orders that were made in the last 30 days.",
    "What are the total sales grouped by region?",
    "List employees whose salary is greater than $50,000.",
    "Get the names of customers and their lifetime value.",
    "Show products where the sales quantity is above 100.",
    "Which departments are managed by the employee with ID 5?",
    "List all sales representatives working in the 'West' region.",
    "What is the most recent sales report date?",
    "Find the average salary of employees in each department.",
    "List customers who have placed more than 5 orders.",
    "Find the total revenue generated by each product.",
    "Show the department with the highest number of employees.",
    "Get the total number of sales representatives in each region.",
    "Find the average lifetime value of customers grouped by region.",
    "List the top 3 products with the highest revenue.",
    "Find employees who have not received a salary in the last year.",
    "Show the total number of orders per customer.",
    "List all departments with more than one manager."
]


In [15]:
queries = [
    "How many stars are in milky way galaxy?",
    "Show me the names of employees and their departments."
]

In [18]:
csv_filename = "sql_queries_results_v2.csv"
save_queries_to_csv(queries, csv_filename)


llama_print_timings:        load time =   34570.70 ms
llama_print_timings:      sample time =       5.44 ms /    36 runs   (    0.15 ms per token,  6620.08 tokens per second)
llama_print_timings: prompt eval time =   48608.77 ms /   679 tokens (   71.59 ms per token,    13.97 tokens per second)
llama_print_timings:        eval time =   12112.04 ms /    35 runs   (  346.06 ms per token,     2.89 tokens per second)
llama_print_timings:       total time =   60767.49 ms /   714 tokens


Generated SQL Query:
SELECT COUNT(performance_id::TEXT::INT) AS star_count FROM analytics.product_performance WHERE product_name ilike '%milky%way%galaxy%'

Original SQL Query:
SELECT COUNT(performance_id::TEXT::INT) AS star_count FROM analytics.product_performance WHERE product_name ilike '%milky%way%galaxy%'


Fixed SQL Query:
SELECT COUNT(performance_id::TEXT::INT) AS star_count
FROM analytics.product_performance
WHERE product_name ILIKE '%milky%way%galaxy%'



Final SQL Query After Verification and Correction:
SELECT COUNT(performance_id::TEXT::INT) AS star_count FROM analytics.product_performance WHERE product_name ilike '%milky%way%galaxy%'

Saved query: How many stars are in milky way galaxy?
Generated SQL:
SELECT COUNT(performance_id::TEXT::INT) AS star_count FROM analytics.product_performance WHERE product_name ilike '%milky%way%galaxy%'



Llama.generate: prefix-match hit

llama_print_timings:        load time =   34570.70 ms
llama_print_timings:      sample time =       6.19 ms /    45 runs   (    0.14 ms per token,  7265.10 tokens per second)
llama_print_timings: prompt eval time =   38422.17 ms /   642 tokens (   59.85 ms per token,    16.71 tokens per second)
llama_print_timings:        eval time =   15180.02 ms /    44 runs   (  345.00 ms per token,     2.90 tokens per second)
llama_print_timings:       total time =   53670.92 ms /   686 tokens


Generated SQL Query:
SELECT e.first_name, e.last_name, d.department_name FROM sales.employees AS e JOIN sales.departments AS d ON e.department_id = d.department_id;

Original SQL Query:
SELECT e.first_name, e.last_name, d.department_name FROM sales.employees AS e JOIN sales.departments AS d ON e.department_id = d.department_id;


Fixed SQL Query:
SELECT
    e.first_name,
    e.last_name,
    d.department_name
FROM sales.employees AS e
INNER JOIN sales.departments AS d ON e.department_id = d.department_id;



Final SQL Query After Verification and Correction:
SELECT e.first_name, e.last_name, d.department_name FROM sales.employees AS e JOIN sales.departments AS d ON e.department_id = d.department_id;

Saved query: Show me the names of employees and their departments.
Generated SQL:
SELECT e.first_name, e.last_name, d.department_name FROM sales.employees AS e JOIN sales.departments AS d ON e.department_id = d.department_id;



# ToDo: Add verification if sql query matches the sql tables

In [19]:
import sqlparse
from collections import defaultdict

def parse_ddl_statements(ddl_statements):
    """
    Parses SQL DDL statements and converts them into a dictionary format.
    """
    schema_definitions = defaultdict(dict)
    current_schema = None
    
    # Split the SQL DDL statements into individual statements
    statements = sqlparse.split(ddl_statements)
    
    for statement in statements:
        parsed = sqlparse.parse(statement)[0]
        tokens = [token for token in parsed.tokens if not token.is_whitespace]
        
        # Identify schema creation
        if "CREATE SCHEMA" in statement.upper():
            # Extract schema name
            for token in tokens:
                if isinstance(token, sqlparse.sql.Identifier):
                    current_schema = str(token).strip().strip(';')
                    break
        
        # Identify table creation
        if "CREATE TABLE" in statement.upper():
            if not current_schema:
                raise ValueError("Table defined without schema context.")
            
            # Extract table name
            table_name_token = tokens[2]
            table_name = str(table_name_token).strip().strip(';')
            if '.' not in table_name:
                table_name = f"{current_schema}.{table_name}"
            
            # Extract columns
            columns_section = parsed.token_next_by(i=sqlparse.sql.Parenthesis)
            if columns_section:
                # Extract the content within the parentheses
                columns_content = columns_section[1] if isinstance(columns_section, tuple) else columns_section
                column_tokens = [token for token in columns_content.tokens if isinstance(token, sqlparse.sql.Identifier)]
                columns = [str(token).strip().strip(';') for token in column_tokens]
                schema_definitions[current_schema][table_name] = columns
    
    print(schema_definitions)
    return dict(schema_definitions)


def extract_tables_and_columns(sql_query):
    parsed_query = sqlparse.parse(sql_query)
    tokens = parsed_query[0].tokens
    
    tables = []
    columns = []
    
    # Iterate through the tokens to extract tables and columns
    for token in tokens:
        if isinstance(token, sqlparse.sql.Identifier):
            # Extract table names, including schema if available
            if token.get_parent_name():
                tables.append(token.get_parent_name() + '.' + token.get_real_name())
            else:
                tables.append(token.get_real_name())
                
            # Extract columns
            if token.get_alias():
                columns.append(token.get_alias())
    
    print("tables: ", tables)
    print("columns: ", tables)
    return tables, columns

def verify_sql_query(sql_query, schema_definitions):
    tables, columns = extract_tables_and_columns(sql_query)
    
    errors = []
    
    # Check if tables exist in schema definitions
    for table in tables:
        schema_table_found = False
        table_name = table.split('.')[-1]  # Get the table name without schema
        schema_name = table.split('.')[0] if '.' in table else None
        
        # Check if table exists in any schema
        if schema_name:
            if schema_table_found := (f"{schema_name}.{table_name}" in schema_definitions.get(schema_name, {})):
                pass
        else:
            # If no schema specified, check all schemas
            schema_table_found = any(f"{schema}.{table_name}" in tables for schema in schema_definitions)
        
        if not schema_table_found:
            errors.append(f"Table '{table}' does not exist in any schema.")
    
    # Check if columns exist in the identified tables
    for table in tables:
        table_name = table.split('.')[-1]
        schema_name = table.split('.')[0] if '.' in table else None
        
        # Determine columns in the table
        if schema_name:
            table_columns = schema_definitions.get(schema_name, {}).get(f"{schema_name}.{table_name}", [])
        else:
            # Find in all schemas if no schema was specified
            table_columns = []
            for schema in schema_definitions:
                if f"{schema}.{table_name}" in schema_definitions[schema]:
                    table_columns = schema_definitions[schema][f"{schema}.{table_name}"]
                    break
        
        # Validate columns
        for column in columns:
            if column not in table_columns:
                errors.append(f"Column '{column}' does not exist in table '{table}'.")
    
    if errors:
        return False, errors
    else:
        return True, "SQL query is valid."



In [20]:
# Example usage
ddl_statements = """
-- Create schemas
CREATE SCHEMA public;
CREATE SCHEMA sales;
CREATE SCHEMA analytics;

-- Create tables in the public schema
CREATE TABLE public.employees (
    employee_id INTEGER PRIMARY KEY,
    first_name VARCHAR,
    last_name VARCHAR,
    age INTEGER,
    department_id INTEGER,
    hire_date DATE,
    FOREIGN KEY (department_id) REFERENCES public.departments(department_id)
);

CREATE TABLE public.departments (
    department_id INTEGER PRIMARY KEY,
    department_name VARCHAR,
    manager_id INTEGER,
    FOREIGN KEY (manager_id) REFERENCES public.employees(employee_id)
);

CREATE TABLE public.salaries (
    employee_id INTEGER PRIMARY KEY,
    salary_amount DECIMAL,
    effective_date DATE,
    FOREIGN KEY (employee_id) REFERENCES public.employees(employee_id)
);

-- Create tables in the sales schema
CREATE TABLE sales.orders (
    order_id INTEGER PRIMARY KEY,
    order_date DATE,
    customer_id INTEGER,
    sales_rep_id INTEGER,
    FOREIGN KEY (customer_id) REFERENCES sales.customers(customer_id),
    FOREIGN KEY (sales_rep_id) REFERENCES sales.sales_reps(sales_rep_id)
);

CREATE TABLE sales.customers (
    customer_id INTEGER PRIMARY KEY,
    customer_name VARCHAR,
    contact_number VARCHAR
);

CREATE TABLE sales.sales_reps (
    sales_rep_id INTEGER PRIMARY KEY,
    first_name VARCHAR,
    last_name VARCHAR,
    region VARCHAR
);

CREATE TABLE sales.products (
    product_id INTEGER PRIMARY KEY,
    product_name VARCHAR,
    price DECIMAL
);

-- Create tables in the analytics schema
CREATE TABLE analytics.sales_reports (
    report_id INTEGER PRIMARY KEY,
    report_date DATE,
    total_sales DECIMAL,
    region VARCHAR
);

CREATE TABLE analytics.customer_metrics (
    metric_id INTEGER PRIMARY KEY,
    customer_id INTEGER,
    lifetime_value DECIMAL,
    average_order_value DECIMAL,
    FOREIGN KEY (customer_id) REFERENCES sales.customers(customer_id)
);

CREATE TABLE analytics.product_performance (
    performance_id INTEGER PRIMARY KEY,
    product_id INTEGER,
    sales_quantity INTEGER,
    revenue_generated DECIMAL,
    FOREIGN KEY (product_id) REFERENCES sales.products(product_id)
);
"""

In [21]:
# Parse the DDL statements to get schema definitions
schema_definitions = parse_ddl_statements(ddl_statements)

# Example SQL query to verify
sql_query = "SELECT e.first_name, e.last_name FROM employees e WHERE e.hire_date >= '2020-01-01';"
is_valid, result = verify_sql_query(sql_query, schema_definitions)
print("SQL Query is valid.") if is_valid else print("SQL Query has errors:\n" + "\n".join(result))


defaultdict(<class 'dict'>, {'analytics': {'analytics.TABLE': ['report_id', 'total_sales', 'region'], 'public.departments': ['department_id', 'manager_id', 'public.employees(employee_id)'], 'public.salaries': ['employee_id', 'effective_date', 'public.employees(employee_id)'], 'sales.customers': ['customer_id', 'contact_number'], 'sales.sales_reps': ['sales_rep_id', 'last_name', 'region'], 'sales.products': ['product_id', 'price'], 'analytics.customer_metrics': ['metric_id', 'lifetime_value', 'average_order_value', 'sales.customers(customer_id)'], 'analytics.product_performance': ['performance_id', 'sales_quantity', 'revenue_generated', 'sales.products(product_id)']}})
tables:  ['employees']
columns:  ['employees']
SQL Query has errors:
Table 'employees' does not exist in any schema.
Column 'e' does not exist in table 'employees'.


In [None]:
import sqlparse
import json
from collections import defaultdict

def parse_ddl_to_json(ddl_statements):
    """
    Parses SQL DDL statements and converts them into a JSON format.
    """
    schema_definitions = defaultdict(lambda: defaultdict(list))
    current_schema = None
    
    # Split the SQL DDL statements into individual statements
    statements = sqlparse.split(ddl_statements)
    
    for statement in statements:
        # Parse each statement
        parsed = sqlparse.parse(statement)[0]
        tokens = [token for token in parsed.tokens if not token.is_whitespace]
        
        # Identify schema creation
        if "CREATE SCHEMA" in statement.upper():
            current_schema = str(tokens[-1]).strip().strip(';')
        
        # Identify table creation
        if "CREATE TABLE" in statement.upper():
            if not current_schema:
                raise ValueError("Table defined without schema context.")
            
            table_name = str(tokens[2]).strip().strip(';')
            if '.' not in table_name:
                table_name = f"{current_schema}.{table_name}"
            
            # Extract columns from the parentheses
            columns_section = parsed.token_next_by(i=sqlparse.sql.Parenthesis)
            if columns_section:
                columns_section = columns_section[1]  # The parenthesis contains the columns
                column_tokens = [token for token in columns_section.tokens if isinstance(token, sqlparse.sql.Identifier)]
                columns = [str(column.get_real_name()).strip() for column in column_tokens]
                
                schema_definitions[current_schema][table_name] = columns
    
    # Convert to JSON format
    schema_json = json.dumps(schema_definitions, indent=4)
    return schema_json

In [None]:
schema_json = parse_ddl_to_json(ddl_statements)
print(schema_json)

ValueError: Table defined without schema context.