### Libraries and API Keys:

In [1]:
# Importing the necessary libraries:
import os
import pandas as pd
import numpy as np
import transformers
import langchain
import pymysql
import tiktoken
from langchain.text_splitter import TokenTextSplitter
from langchain import HuggingFaceHub, SQLDatabase, PromptTemplate, FewShotPromptTemplate
from langchain.chains import SQLDatabaseChain
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from transformers import RagModel, RagTokenizer, RagRetriever

# Setting the environment for the Hugging face API:
os.environ["HUGGINGFACEHUB_API_TOKEN"] = "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"

### Flan-T5-xxl is the only reliable model:

In [2]:
# Prompt Attempt:
model = "google/flan-t5-xxl" #'Sandiago21/llama-7b-hf-prompt-answering'

llm = HuggingFaceHub(repo_id=model, model_kwargs={"temperature":0.5, "max_length":512})

prompt = "Alice has a parrot. What animal is Alice's pet?"
print(llm(prompt))

parrot


### Creating an LLM SQL chain with the winestore database

#### Intention is the to query the SQL DB with the LLM's response.

In [3]:
# Be sure to add the password to the db and the name of the schema
sql_uri = "mysql+pymysql://root:Liklik69&@localhost:3306/wine_store"

In [4]:
db = SQLDatabase.from_uri(sql_uri)

In [5]:
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

### Zero Shot:

In [6]:
PROMPT = """ 
Given an input question, first create a syntactically correct MySQL query to run,  
then look at the results of the query and return the answer.  
The question: {question}
"""

In [7]:
question = "what is the average price from the table purchasepricesdec?" 

db_chain.run(PROMPT.format(question=question))



[1m> Entering new  chain...[0m
 
Given an input question, first create a syntactically correct MySQL query to run,  
then look at the results of the query and return the answer.  
The question: what is the average price from the table purchasepricesdec?

SQLQuery:[32;1m[1;3mSELECT avg(Price) FROM purchasepricesdec[0m
SQLResult: [33;1m[1;3m[(38.65188805482593,)][0m
Answer:[32;1m[1;3m38.65188805482593[0m
[1m> Finished chain.[0m


'38.65188805482593'

### Few Shot:

In [8]:
# Few Shot:
examples = [
  {"question": "What is the average of the PurchasePrice in the table purchasepricesdec?",
    "answer": "26.48821956"},
  {"question": "How many unique values do we have in the Classification column in the table purchasepricesdec?",
    "answer": "2"},
  {"question": "How many rows do we have in the table purchasepricesdec?",
    "answer":"12262"},
  {"question": "What is the mode of the Volume column in purchasepricesdec table?",
    "answer":"750"}
]

example_prompt = PromptTemplate(input_variables=["question", "answer"], template="Question: {question}\n{answer}")

prompt = FewShotPromptTemplate(
    examples=examples, 
    example_prompt=example_prompt, 
    suffix="Question: {input}", 
    input_variables=["input"]
)

db_chain.run(prompt.format(input="What is the mean of the Volume column in the purchasepricesdec table?"))



[1m> Entering new  chain...[0m
Question: What is the average of the PurchasePrice in the table purchasepricesdec?
26.48821956

Question: How many unique values do we have in the Classification column in the table purchasepricesdec?
2

Question: How many rows do we have in the table purchasepricesdec?
12262

Question: What is the mode of the Volume column in purchasepricesdec table?
750

Question: What is the mean of the Volume column in the purchasepricesdec table?
SQLQuery:[32;1m[1;3mSELECT avg(T1.PurchasePrice) FROM purchase[0m

ProgrammingError: (pymysql.err.ProgrammingError) (1146, "Table 'wine_store.purchase' doesn't exist")
[SQL: SELECT avg(T1.PurchasePrice) FROM purchase]
(Background on this error at: https://sqlalche.me/e/14/f405)

### RAG model:

#### Using WikiSQL dataset for the retrieval queries. Create a vector database (FAISS indexed) of the csv file by embedding the csv and use the similarity search feature to query the LLM.

In [10]:
input_data = pd.read_csv(r"C:\Users\Rutvik\LLM_SQL\train\train_text2sql.csv")
input_data.head()

Unnamed: 0,question,sql
0,Tell me what the notes are for South Australia,SELECT Notes FROM table WHERE Current slogan =...
1,What is the current series where the new serie...,SELECT Current series FROM table WHERE Notes =...
2,What is the format for South Australia?,SELECT Format FROM table WHERE State/territory...
3,Name the background colour for the Australian ...,SELECT Text/background colour FROM table WHERE...
4,how many times is the fuel propulsion is cng?,SELECT COUNT Fleet Series (Quantity) FROM tabl...


In [11]:
# Using only 0.01 of the 54,000 rows of the train dataset as the embeddings is taking a long time.
#input_df = input_data.sample(frac=0.01)
#input_df.to_csv(r"C:\Users\Rutvik\LLM_SQL\train\train_text2sql_small.csv", index=False)

In [12]:
nl_and_sql = CSVLoader(file_path=r"C:\Users\Rutvik\LLM_SQL\train\train_text2sql_small.csv", encoding="utf8")
docs = nl_and_sql.load()
docs

[Document(page_content='question: How many different countries did the champion Se Ri Pak (2) represent?\nsql: SELECT COUNT Country FROM table WHERE Champion = Se Ri Pak (2)', metadata={'source': 'C:\\Users\\Rutvik\\LLM_SQL\\train\\train_text2sql_small.csv', 'row': 0}),
 Document(page_content='question: On what date is Jerilyn Britz the runner-up?\nsql: SELECT Date FROM table WHERE Runner(s)-up = jerilyn britz', metadata={'source': 'C:\\Users\\Rutvik\\LLM_SQL\\train\\train_text2sql_small.csv', 'row': 1}),
 Document(page_content='question: What is the fixed charge for the user who had a tariff of 11.30?\nsql: SELECT Fixed Charge ( Rs. /kWh) FROM table WHERE Tariff ( Rs. /kWh) = 11.30', metadata={'source': 'C:\\Users\\Rutvik\\LLM_SQL\\train\\train_text2sql_small.csv', 'row': 2}),
 Document(page_content='question: How many people voted in Cabarrus county?\nsql: SELECT MIN Total FROM table WHERE County = Cabarrus', metadata={'source': 'C:\\Users\\Rutvik\\LLM_SQL\\train\\train_text2sql_smal

#### Vector Database:

In [9]:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-MiniLM-L6-v2')

#Our sentences we like to encode
sentences = ['This framework generates embeddings for each input sentence',
    'Sentences are passed as a list of string.',
    'The quick brown fox jumps over the lazy dog.']

#Sentences are encoded by calling model.encode()
embeddings = model.encode(sentences)

#Print the embeddings
for sentence, embedding in zip(sentences, embeddings):
    print("Sentence:", sentence)
    print("Embedding:", embedding)
    print("")

Sentence: This framework generates embeddings for each input sentence
Embedding: [-1.37173375e-02 -4.28515524e-02 -1.56286061e-02  1.40537517e-02
  3.95537987e-02  1.21796295e-01  2.94333640e-02 -3.17523815e-02
  3.54959480e-02 -7.93140158e-02  1.75878331e-02 -4.04369719e-02
  4.97259796e-02  2.54912823e-02 -7.18700588e-02  8.14968571e-02
  1.47074426e-03  4.79627438e-02 -4.50336039e-02 -9.92174968e-02
 -2.81769708e-02  6.45046085e-02  4.44670543e-02 -4.76217009e-02
 -3.52952927e-02  4.38671596e-02 -5.28566539e-02  4.32992238e-04
  1.01921462e-01  1.64072327e-02  3.26996557e-02 -3.45986672e-02
  1.21339196e-02  7.94871375e-02  4.58344072e-03  1.57778300e-02
 -9.68208164e-03  2.87625827e-02 -5.05806468e-02 -1.55794267e-02
 -2.87907030e-02 -9.62285884e-03  3.15556601e-02  2.27349512e-02
  8.71449560e-02 -3.85027416e-02 -8.84718373e-02 -8.75498727e-03
 -2.12342981e-02  2.08923817e-02 -9.02077556e-02 -5.25732338e-02
 -1.05638532e-02  2.88310889e-02 -1.61455162e-02  6.17843913e-03
 -1.23234

In [17]:
# Split text into chunks of 512 tokens, with 20% token overlap:
text_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=103)

def num_tokens_from_string(string: str, encoding_name="cl100k_base") -> int:
    """Calculate number of tokens in a text string."""
    if not string:
        return 0
    
    # Returns the number of tokens in a text string
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(encoding.encode(string))
    
    return num_tokens

In [19]:
# List for smaller chunked text and metadata:
new_list = []

# Create a new list by splitting up text into token sizes of around 512 tokens:
for i in range(len(input_data.index)):
    text = input_data['question'][i]
    token_len = num_tokens_from_string(text)
    if token_len <= 512:
        new_list.append([input_data['question'][i], input_data['sql'][i]])
    else:
        #split text into chunks using text splitter
        split_text = text_splitter.split_text(text)
        for j in range(len(split_text)):
            new_list.append([input_data['question'][i], split_text[j], input_data['sql'][i]])

In [20]:
df_new = pd.DataFrame(new_list, columns=['question', 'sql'])
df_new.head()

Unnamed: 0,question,sql
0,Tell me what the notes are for South Australia,SELECT Notes FROM table WHERE Current slogan =...
1,What is the current series where the new serie...,SELECT Current series FROM table WHERE Notes =...
2,What is the format for South Australia?,SELECT Format FROM table WHERE State/territory...
3,Name the background colour for the Australian ...,SELECT Text/background colour FROM table WHERE...
4,how many times is the fuel propulsion is cng?,SELECT COUNT Fleet Series (Quantity) FROM tabl...


In [21]:
from langchain.document_loaders import DataFrameLoader

# page_content_column is the column name in the dataframe to create embeddings for:
loader = DataFrameLoader(df_new, page_content_column = 'question')
docs = loader.load()

In [None]:
embeddings = HuggingFaceEmbeddings()
retrieval_db = FAISS.from_documents(docs, embeddings)

In [None]:
def retrieve_info(query):
    """Performs similarity search basis the input query and returns the contents as an array."""
    
    simi_search = retrieval_db.similarity_search(query, k=3)
    retrieved_info = [info.page_content for info in simi_search]    
    return retrieved_info

In [None]:
question = """Get the data from purchasepricesdec table where the price is greater than 100?"""

In [None]:
print(retrieve_info(query))

In [None]:
print(retrieve_info(query)[0])

In [None]:
template = """ 
Given a question, first create a syntactically correct MySQL query by referring to the help provided in the sample questions 
and answers, then display the results of the MySQL query.  

The following is the question from the user:
{question}

sample question and sql answers:
{sample}
"""

prompt = PromptTemplate(input_variables=["question", "sample"], template=template)

In [None]:
rag_db_chain = SQLDatabaseChain.from_llm(llm=llm, db=db, verbose=True)

In [None]:
sample_ret = retrieve_info(question)
response = rag_db_chain.run(template.format(question=question, sample=sample_ret))
response