In [1]:
from langchain.llms import LlamaCpp
from langchain.prompts import PromptTemplate
from langchain.schema import BaseOutputParser
from langchain.chains import LLMChain
import re
import pandas as pd
import torch
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn import svm
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve, precision_recall_curve, accuracy_score, confusion_matrix, auc, recall_score, precision_score
import matplotlib.pyplot as plt
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

In [2]:
# Load the quantized SQLCoder2 GGUF model through LangChain's llama.cpp wrapper.
# The path uses a forward slash: the previous "sqlcoder2-GGUF\sqlcoder2..." form
# relied on "\s" being an unrecognized escape sequence (which newer Python
# versions warn about) and only resolved on Windows. Forward slashes work on
# Windows, Linux and macOS alike.
llm = LlamaCpp(model_path="sqlcoder2-GGUF/sqlcoder2.Q4_K_M.gguf",
               n_batch=512,       # batch size handed to llama.cpp
               n_ctx=2048,        # context window requested for this session
               n_gpu_layers=30,   # number of layers to offload to the GPU
               verbose=True)      # emit llama.cpp loader/timing logs

llama_model_loader: loaded meta data with 19 key-value pairs and 485 tensors from sqlcoder2-GGUF\sqlcoder2.Q4_K_M.gguf (version GGUF V2)
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = starcoder
llama_model_loader: - kv   1:                               general.name str              = StarCoder
llama_model_loader: - kv   2:                   starcoder.context_length u32              = 8192
llama_model_loader: - kv   3:                 starcoder.embedding_length u32              = 6144
llama_model_loader: - kv   4:              starcoder.feed_forward_length u32              = 24576
llama_model_loader: - kv   5:                      starcoder.block_count u32              = 40
llama_model_loader: - kv   6:             starcoder.attention.head_count u32              = 48
llama_model_loader: - kv   7:          starcoder.attention.head_count_kv u32   

In [4]:
# Schema description supplied to the model as the `ddl_statements` prompt
# variable. It declares three schemas (public, sales, analytics) and their
# tables.
# NOTE(review): public.employees and public.departments reference each other,
# and some foreign keys point at tables declared further down, so this text is
# presumably meant as prompt context rather than as a script to execute as-is
# -- confirm before running it against a database.
ddl_statements = """
-- Create schemas
CREATE SCHEMA public;
CREATE SCHEMA sales;
CREATE SCHEMA analytics;

-- Create tables in the public schema
CREATE TABLE public.employees (
    employee_id INTEGER PRIMARY KEY,
    first_name VARCHAR,
    last_name VARCHAR,
    department_id INTEGER,
    hire_date DATE,
    FOREIGN KEY (department_id) REFERENCES public.departments(department_id)
);

CREATE TABLE public.departments (
    department_id INTEGER PRIMARY KEY,
    department_name VARCHAR,
    manager_id INTEGER,
    FOREIGN KEY (manager_id) REFERENCES public.employees(employee_id)
);

CREATE TABLE public.salaries (
    employee_id INTEGER PRIMARY KEY,
    salary_amount DECIMAL,
    effective_date DATE,
    FOREIGN KEY (employee_id) REFERENCES public.employees(employee_id)
);

-- Create tables in the sales schema
CREATE TABLE sales.orders (
    order_id INTEGER PRIMARY KEY,
    order_date DATE,
    customer_id INTEGER,
    sales_rep_id INTEGER,
    FOREIGN KEY (customer_id) REFERENCES sales.customers(customer_id),
    FOREIGN KEY (sales_rep_id) REFERENCES sales.sales_reps(sales_rep_id)
);

CREATE TABLE sales.customers (
    customer_id INTEGER PRIMARY KEY,
    customer_name VARCHAR,
    contact_number VARCHAR
);

CREATE TABLE sales.sales_reps (
    sales_rep_id INTEGER PRIMARY KEY,
    first_name VARCHAR,
    last_name VARCHAR,
    region VARCHAR
);

CREATE TABLE sales.products (
    product_id INTEGER PRIMARY KEY,
    product_name VARCHAR,
    price DECIMAL
);

-- Create tables in the analytics schema
CREATE TABLE analytics.sales_reports (
    report_id INTEGER PRIMARY KEY,
    report_date DATE,
    total_sales DECIMAL,
    region VARCHAR
);

CREATE TABLE analytics.customer_metrics (
    metric_id INTEGER PRIMARY KEY,
    customer_id INTEGER,
    lifetime_value DECIMAL,
    average_order_value DECIMAL,
    FOREIGN KEY (customer_id) REFERENCES sales.customers(customer_id)
);

CREATE TABLE analytics.product_performance (
    performance_id INTEGER PRIMARY KEY,
    product_id INTEGER,
    sales_quantity INTEGER,
    revenue_generated DECIMAL,
    FOREIGN KEY (product_id) REFERENCES sales.products(product_id)
);
"""

In [5]:
# Prompt scaffold with three placeholders: {user_question} (used twice),
# {instructions} and {ddl_statements}. The text ends on an opened ```sql fence
# so that the model's completion starts directly with the query.
# NOTE(review): the <|begin_of_text|> / <|start_header_id|> / <|eot_id|> markers
# look like a Llama-3-style chat format, while the loader log for this model
# reports `general.architecture = starcoder` -- confirm these markers match the
# prompt format the loaded model was trained with.
template = """
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{user_question}`
{instructions}

DDL statements:
{ddl_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{user_question}`:
```sql
"""

In [6]:
# Wrap the raw template text so its three placeholders are filled by name.
prompt = PromptTemplate(
    input_variables=["user_question", "instructions", "ddl_statements"],
    template=template,
)

In [7]:
# Pipe the prompt template into the model: invoking the chain with a dict of
# template variables formats the prompt and passes it to `llm`.
llm_chain = prompt | llm

In [8]:
# Example question for the chain.
# Note: public.employees in ddl_statements has hire_date but no age column.
user_question = "Select the name and age of employees who joined after January 2020"
# Extra guidance inserted into the prompt's {instructions} slot; empty here.
instructions = ""

In [9]:
# Fill every template placeholder and run the chain once; the value returned
# by the chain is kept as-is and printed.
raw_llm_answer = llm_chain.invoke(
    {
        "user_question": user_question,
        "instructions": instructions,
        "ddl_statements": ddl_statements,
    }
)

print(raw_llm_answer)


llama_print_timings:        load time =    6592.01 ms
llama_print_timings:      sample time =       8.48 ms /    60 runs   (    0.14 ms per token,  7077.14 tokens per second)
llama_print_timings: prompt eval time =    7463.90 ms /   664 tokens (   11.24 ms per token,    88.96 tokens per second)
llama_print_timings:        eval time =   11196.30 ms /    59 runs   (  189.77 ms per token,     5.27 tokens per second)
llama_print_timings:       total time =   18741.14 ms /   723 tokens


SELECT e.first_name, e.last_name, EXTRACT(YEAR FROM AGE(CURRENT_DATE, e.hire_date)) AS employee_age FROM employees e WHERE e.hire_date > '2020-01-01';


In [None]:
def generate_sql_query(natural_language_query, instructions="", schema_ddl=None):
    """Turn a natural-language question into SQL text via the notebook's prompt chain.

    The question is routed through ``llm_chain`` (prompt template piped into the
    model) so the model receives the prompt scaffold and the schema DDL. The
    earlier implementation called ``llm.generate`` with a bare string, which
    skips the template and DDL entirely; LangChain's ``generate`` also expects a
    list of prompts and returns an ``LLMResult`` rather than a string.

    Args:
        natural_language_query: The question to answer, as a string.
        instructions: Optional extra guidance for the ``{instructions}`` slot
            of the prompt. Defaults to an empty string.
        schema_ddl: DDL text for the ``{ddl_statements}`` slot. When ``None``,
            the notebook-level ``ddl_statements`` is used.

    Returns:
        The generated SQL as a whitespace-stripped string, or ``None`` when the
        question is blank / not a string or when the chain raises.
    """
    # A blank question cannot yield a meaningful query; skip the (slow) model call.
    if not isinstance(natural_language_query, str) or not natural_language_query.strip():
        print("Error generating SQL query: the question is empty")
        return None

    try:
        raw_answer = llm_chain.invoke({
            "user_question": natural_language_query,
            "instructions": instructions,
            "ddl_statements": ddl_statements if schema_ddl is None else schema_ddl,
        })
        # The prompt ends on an opened ```sql fence, so the completion may close
        # it and keep going; keep only the text before the first closing fence.
        return raw_answer.split("```", 1)[0].strip()
    except Exception as e:
        # Best-effort by design: report the failure and let the caller see None.
        print(f"Error generating SQL query: {e}")
        return None

def sql_pipeline(natural_language_query):
    """Run the text-to-SQL pipeline for one question.

    Currently a single stage: the question is handed to
    ``generate_sql_query`` and its result is returned unchanged.
    """
    return generate_sql_query(natural_language_query)

# Example usage: push one question through sql_pipeline and print the result.
# NOTE(review): this cell is marked `In [None]`, i.e. it has no recorded
# execution in the saved notebook.
nl_query = "Select the name and age of employees who joined after January 2020"
sql_query = sql_pipeline(nl_query)
print("Generated SQL Query:", sql_query)