# Enhancing SQL Queries with Few-Shot Learning via CodeLlama and LangChain

## Prerequisites for running this notebook
To run this notebook, we first need to meet a few prerequisites. You can fork the code from https://github.com/yernenip/CodeLlama-LangChain-MySql

1. First, get MySQL up and running (for example, in a Docker container) with the Sakila sample database installed. The README in the Git repository above explains how to do this.
2. Next, install the packages required by this notebook with `pip install`. These packages are:

* ctransformers (base package with no GPU acceleration, as I am running this locally on a CPU)
* ctransformers[cuda] (install this instead if you have CUDA support; it provides GPU acceleration)
* langchain (core library)
* langchain_experimental (provides SQLDatabaseChain)
* sqlalchemy (SQLDatabase uses this under the hood)
* mysqlclient (the MySQL driver SQLAlchemy uses for `mysql://` URIs)
* sentence_transformers (required by HuggingFaceEmbeddings)
* chromadb (vector database)
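A quick way to confirm the installs worked is to check that each package can be found on the import path. This is a minimal sketch using only the standard library's `importlib.util.find_spec`; `check_modules` is a hypothetical helper, and the list holds the import names the notebook's code uses, which can differ from pip package names.

```python
# Sanity check that the packages imported by this notebook are installed.
# Import names can differ from pip names (e.g. pip "sentence-transformers"
# is imported as "sentence_transformers").
import importlib.util

REQUIRED_MODULES = [
    "ctransformers",
    "langchain",
    "langchain_experimental",
    "sqlalchemy",
    "sentence_transformers",
    "chromadb",
]


def check_modules(names):
    """Map each top-level module name to True if it can be found on the import path."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


for name, found in check_modules(REQUIRED_MODULES).items():
    print(f"{name:<24} {'ok' if found else 'MISSING'}")
```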



## Getting the CodeLlama LLM

We are going to use a quantized version of CodeLlama 7B so that it fits in memory when running locally on a laptop. The full-precision 7B model stores each weight as a 32-bit (4-byte) float, so it would take roughly 7 billion × 4 bytes = 28 GB of memory. The Q4 version stores about 4 bits (half a byte) per weight, so it should take roughly 3.5 GB to load. However, memory use nearly doubles when we run inference, so we end up consuming over 7 GB of RAM.
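The figures above come from simple arithmetic: parameter count times bits per weight. A minimal sketch, assuming 1 GB = 10^9 bytes and counting only the weights (not the KV cache or activations that account for the extra memory during inference); `weight_memory_gb` is a hypothetical helper. In practice the Q4_K_M file is somewhat larger than the 4-bit estimate because quantized blocks also store scale factors.

```python
# Weight-only memory estimate for a 7B-parameter model.
# ASSUMPTIONS: 1 GB = 10**9 bytes; ignores the KV cache and activations used
# during inference, and the per-block scale factors that Q4_K_M files also store.
PARAMS = 7_000_000_000


def weight_memory_gb(params: int, bits_per_weight: float) -> float:
    """Memory needed to hold `params` weights at `bits_per_weight` bits each, in GB."""
    return params * bits_per_weight / 8 / 1e9


for label, bits in [("fp32", 32), ("fp16", 16), ("4-bit", 4)]:
    print(f"{label:>5}: {weight_memory_gb(PARAMS, bits):.1f} GB")
```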

We are also setting `context_length` to 10000, which is really large. I could not find a better way: because we are relying on prompt-based (in-context) learning rather than fine-tuning, all the context (schemas, tables, example queries) has to be passed in the prompt for this to work. Let me know if there is a better way to do this.
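To get a feel for whether a prompt of this kind fits in a 10000-token window, a rough estimate is often enough. A minimal sketch, assuming the common rule of thumb of about four characters per token; CodeLlama's actual tokenizer will give different counts, particularly for SQL and schema text. `estimate_tokens` and `fits_in_context` are hypothetical helpers.

```python
# Rough check that a prompt plus the generation budget fits in the context window.
# ASSUMPTION: ~4 characters per token is a rule of thumb for English text; the
# real count comes from CodeLlama's tokenizer and will differ, especially for SQL.
CHARS_PER_TOKEN = 4


def estimate_tokens(text: str) -> int:
    """Very rough token estimate: character count divided by CHARS_PER_TOKEN, rounded up."""
    return -(-len(text) // CHARS_PER_TOKEN)  # ceiling division on integers


def fits_in_context(prompt: str, context_length: int, max_new_tokens: int) -> bool:
    """True if the estimated prompt tokens plus the tokens to generate fit in the window."""
    return estimate_tokens(prompt) + max_new_tokens <= context_length


# Hypothetical prompt: a repeated placeholder stands in for db.table_info.
schema = "CREATE TABLE customer (...)\n" * 200
prompt = schema + "\nQuestion: How many customers are from country Canada?"

print(estimate_tokens(prompt), fits_in_context(prompt, context_length=10000, max_new_tokens=256))
```

In the notebook, the prompt would be built from the database schema, the selected examples and the question; the `max_new_tokens` budget also has to fit in the same window.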



In [None]:
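import langchain
from langchain.llms import CTransformers

# Generation settings for CTransformers; see https://github.com/marella/ctransformers#config for all options.
# context_length must be large enough to hold the schema, examples and question.
# With ctransformers[cuda] installed, adding e.g. 'gpu_layers': 50 offloads that many layers to the GPU.
config = {'max_new_tokens': 256, 'repetition_penalty': 1.1, 'temperature': 0, 'context_length': 10000}

# Downloads the 4-bit Q4_K_M GGUF build of CodeLlama 7B Instruct from the Hugging Face Hub on first use.
llm = CTransformers(model="TheBloke/CodeLlama-7B-Instruct-GGUF",
                    model_file="codellama-7b-instruct.Q4_K_M.gguf", config=config, verbose=True)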


### Once the model is downloaded, connect to the database

LangChain uses SQLAlchemy under the hood to connect to and query the database.

In [None]:
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain

langchain.verbose = True  # enable verbose logging for LangChain chains

db = SQLDatabase.from_uri(
    'mysql://dbuser:dbpwd@localhost:3306/sakila',
    # include_tables=['customer', 'address', 'city', 'country'],  # include only the tables you want to query; reduces tokens
    sample_rows_in_table_info=3,  # sample rows per table included in the schema description sent to the LLM
)

print(db.table_info)
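The URI passed to `SQLDatabase.from_uri` is a standard SQLAlchemy database URL of the form `dialect://user:password@host:port/database`; with a bare `mysql://` scheme, SQLAlchemy uses its default MySQL driver, mysqlclient. If a password contains characters such as `@` or `:`, they must be percent-encoded. A minimal standard-library sketch; `build_mysql_uri` is a hypothetical helper, and the credentials are the placeholder ones from the cell above.

```python
# Build a SQLAlchemy-style MySQL URI with percent-encoded credentials, so that
# characters such as "@", ":" or "/" in a password do not break URI parsing.
# build_mysql_uri is a hypothetical helper, not part of LangChain or SQLAlchemy.
from urllib.parse import quote_plus


def build_mysql_uri(user: str, password: str, host: str, port: int, database: str) -> str:
    """Return a mysql:// URI with URL-encoded user name and password."""
    return f"mysql://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/{database}"


print(build_mysql_uri("dbuser", "dbpwd", "localhost", 3306, "sakila"))
print(build_mysql_uri("dbuser", "p@ss:word", "localhost", 3306, "sakila"))
```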



### Set up the chain and run inference

In the following cell we run the database chain as-is, i.e. zero-shot: the prompt contains the schema and the question, but no worked examples.

In [None]:
import time

db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, return_sql=False, use_query_checker=True)

start = time.time()

db_chain.run("How many customers are from district California?")

elapsed_time = time.time() - start
print(f"Time taken to construct and run query: {elapsed_time}")
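The `start = time.time()` / `elapsed_time` pattern is repeated in several cells. A minimal sketch of a reusable timer, assuming only wall-clock durations are needed; `timed` is a hypothetical helper, and it uses `time.perf_counter()`, the standard-library clock intended for measuring short durations.

```python
# Reusable timer for the "start = time.time() ... elapsed" pattern.
# time.perf_counter() is a high-resolution clock meant for measuring durations;
# `timed` is a hypothetical helper, not part of LangChain.
import time
from contextlib import contextmanager


@contextmanager
def timed(label: str, results: dict):
    """Measure how long the with-block takes and store it in results[label] (seconds)."""
    start = time.perf_counter()
    try:
        yield
    finally:
        results[label] = time.perf_counter() - start
        print(f"{label}: {results[label]:.2f} s")


timings = {}
with timed("sleep demo", timings):
    time.sleep(0.1)  # stand-in for a call such as db_chain.run(...)
```

In the notebook you would wrap the chain call instead, e.g. `with timed("zero-shot", timings): db_chain.run(...)`. Because CPU inference times vary between runs, timing each prompt a few times gives a steadier basis for comparing the zero-shot and few-shot chains.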

### Creating an example prompt and a list of three examples

In [None]:
from langchain.prompts.prompt import PromptTemplate

# Adjacent string literals inside parentheses are joined without the stray indentation
# that a backslash continuation inside a string literal would embed in the SQL.
examples = [
    {
        "input": "How many customers are from district California?",
        "sql_cmd": ("SELECT COUNT(*) FROM customer cu JOIN address ad ON cu.address_id = ad.address_id "
                    "WHERE ad.district = 'California';"),
        "result": "[(9,)]",
        "answer": "There are 9 customers from California",
    },
    {
        "input": "How many customers are from city San Bernardino?",
        "sql_cmd": ("SELECT COUNT(*) FROM customer cu JOIN address ad ON cu.address_id = ad.address_id "
                    "JOIN city ci ON ad.city_id = ci.city_id WHERE ci.city = 'San Bernardino';"),
        "result": "[(1,)]",
        "answer": "There is 1 customer from San Bernardino",
    },
    {
        "input": "How many customers are from country United States?",
        "sql_cmd": ("SELECT COUNT(*) FROM customer cu JOIN address ad ON cu.address_id = ad.address_id "
                    "JOIN city ci ON ad.city_id = ci.city_id JOIN country co ON ci.country_id = co.country_id "
                    "WHERE co.country = 'United States';"),
        "result": "[(36,)]",
        "answer": "There are 36 customers from United States",
    },
]

example_prompt = PromptTemplate(
    input_variables=["input", "sql_cmd", "result", "answer",],
    template="\nQuestion: {input}\nSQLQuery: {sql_cmd}\nSQLResult: {result}\nAnswer: {answer}",
)

#print(example_prompt.format(**examples[2]))
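To see exactly what one example looks like once rendered (the commented-out `print` above does the same through LangChain), here is a standard-library sketch. It relies on `PromptTemplate` using f-string style placeholders by default, so `str.format` with the same template string yields the same text; the example dict is the first one from the list above.

```python
# Render one few-shot example by hand to see the exact text the model receives.
# PromptTemplate's default placeholders are f-string style, so str.format on the
# same template string produces the same output.
template = "\nQuestion: {input}\nSQLQuery: {sql_cmd}\nSQLResult: {result}\nAnswer: {answer}"

example = {
    "input": "How many customers are from district California?",
    "sql_cmd": "SELECT COUNT(*) FROM customer cu JOIN address ad "
               "ON cu.address_id = ad.address_id WHERE ad.district = 'California';",
    "result": "[(9,)]",
    "answer": "There are 9 customers from California",
}

rendered = template.format(**example)
print(rendered)
```

Each rendered example walks through the full Question, SQLQuery, SQLResult, Answer sequence, which gives the model a concrete pattern to continue for a new question.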

### Vectorizing the examples shared above and storing them in a local Chroma vector store

In [None]:
from langchain.prompts import SemanticSimilarityExampleSelector
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma


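# Uses LangChain's default sentence-transformers embedding model (requires sentence_transformers).
embeddings = HuggingFaceEmbeddings()

# Each example (question, SQL, result and answer) is flattened into one string and embedded.
to_vectorize = [" ".join(example.values()) for example in examples]

# The example dicts are stored as metadata so the selector can return them intact.
vectorstore = Chroma.from_texts(to_vectorize, embeddings, metadatas=examples)

example_selector = SemanticSimilarityExampleSelector(
    vectorstore=vectorstore,
    k=1,  # number of most similar examples inserted into each prompt
)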
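`SemanticSimilarityExampleSelector` embeds the examples, stores the vectors, and at query time returns the `k` examples nearest to the incoming question. The sketch below illustrates that idea without any dependencies. It is not how LangChain or Chroma compute similarity: a toy bag-of-words vector and cosine similarity stand in for the HuggingFace sentence embedding and Chroma's vector search, and it embeds only each example's question rather than the joined question, SQL, result and answer. `embed`, `cosine` and `select_examples` are hypothetical helpers.

```python
# Conceptual sketch of semantic-similarity example selection.
# ASSUMPTION: toy bag-of-words vectors and cosine similarity stand in for real
# sentence embeddings and a vector store; only the selection idea carries over.
import math
import re
from collections import Counter


def embed(text: str) -> Counter:
    """Toy 'embedding': counts of lowercase alphanumeric words."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(a[w] * b[w] for w in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def select_examples(question: str, examples: list, k: int = 1) -> list:
    """Return the k examples whose 'input' is most similar to the question."""
    q = embed(question)
    ranked = sorted(examples, key=lambda ex: cosine(q, embed(ex["input"])), reverse=True)
    return ranked[:k]


examples = [
    {"input": "How many customers are from district California?"},
    {"input": "How many customers are from city San Bernardino?"},
    {"input": "How many customers are from country United States?"},
]

print(select_examples("How many customers are from country Canada?", examples)[0]["input"])
```

With `k=1`, only the single closest example is added to the prompt, which keeps the prompt short given the large schema it already carries.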

### Setting up the few-shot prompt that will be passed to the LLM

In [None]:
from langchain.prompts import FewShotPromptTemplate
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mysql_prompt

#print(PROMPT_SUFFIX)

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=_mysql_prompt,
    suffix=PROMPT_SUFFIX, 
    input_variables=["input", "table_info", "top_k"], #These variables are used in the prefix and suffix
)
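Conceptually, `FewShotPromptTemplate` builds the final prompt from the prefix, the examples chosen by the selector (each rendered with `example_prompt`), and the suffix, joined by a blank line by default; the remaining placeholders (`input`, `table_info`, `top_k`) are filled in at run time, which is why they appear in `input_variables`. A minimal standard-library sketch under that assumption; `PREFIX` and `SUFFIX` are short hypothetical stand-ins for `_mysql_prompt` and `PROMPT_SUFFIX`, not their real text, and `build_few_shot_prompt` is a hypothetical helper.

```python
# Simplified few-shot prompt assembly: prefix, rendered examples, suffix.
# ASSUMPTION: PREFIX and SUFFIX are short stand-ins for LangChain's
# _mysql_prompt and PROMPT_SUFFIX; only the ordering is the point here.
EXAMPLE_TEMPLATE = "\nQuestion: {input}\nSQLQuery: {sql_cmd}\nSQLResult: {result}\nAnswer: {answer}"
PREFIX = "You are a MySQL expert. Return at most {top_k} rows unless told otherwise."
SUFFIX = "Only use the following tables:\n{table_info}\n\nQuestion: {input}"


def build_few_shot_prompt(selected_examples, question, table_info, top_k, separator="\n\n"):
    """Join prefix, rendered examples and suffix, filling in the run-time variables."""
    prefix = PREFIX.format(top_k=top_k)
    shots = [EXAMPLE_TEMPLATE.format(**ex) for ex in selected_examples]
    suffix = SUFFIX.format(table_info=table_info, input=question)
    return separator.join([prefix, *shots, suffix])


# The example a k=1 selector might return for a question about Canada.
selected = [{
    "input": "How many customers are from country United States?",
    "sql_cmd": "SELECT COUNT(*) FROM customer cu JOIN address ad ON cu.address_id = ad.address_id "
               "JOIN city ci ON ad.city_id = ci.city_id JOIN country co ON ci.country_id = co.country_id "
               "WHERE co.country = 'United States';",
    "result": "[(36,)]",
    "answer": "There are 36 customers from United States",
}]

prompt = build_few_shot_prompt(
    selected,
    question="How many customers are from country Canada?",
    table_info="CREATE TABLE country (country_id INT, country VARCHAR(50))",  # placeholder schema
    top_k=5,
)
print(prompt)
```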

### Set up the chain from the LLM and run it

The first prompt below shows how a complex query can be constructed using JOINs. The same question does not get a response from CodeLlama in the zero-shot setting, because the query has to span four tables (customer, address, city and country).

The second prompt repeats the question from the zero-shot run, to check whether performance improved. However, it appears that performance degraded slightly, although the LLM now uses a JOIN instead of a subquery.
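The four-table JOIN that the few-shot example teaches can be tried outside the LLM pipeline. A minimal sketch using Python's built-in `sqlite3` and an in-memory database; the tables keep only the Sakila columns the query needs, and the rows are made-up toy data, not the real Sakila contents. The query mirrors the third example's SQL, with a bound parameter in place of the country name.

```python
# Toy reproduction of the customer -> address -> city -> country JOIN.
# ASSUMPTION: tiny in-memory SQLite tables with made-up rows, NOT the Sakila data.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE country  (country_id INTEGER PRIMARY KEY, country TEXT);
CREATE TABLE city     (city_id INTEGER PRIMARY KEY, city TEXT, country_id INTEGER);
CREATE TABLE address  (address_id INTEGER PRIMARY KEY, district TEXT, city_id INTEGER);
CREATE TABLE customer (customer_id INTEGER PRIMARY KEY, address_id INTEGER);

INSERT INTO country  VALUES (1, 'Canada'), (2, 'United States');
INSERT INTO city     VALUES (1, 'Toronto', 1), (2, 'Vancouver', 1), (3, 'San Bernardino', 2);
INSERT INTO address  VALUES (1, 'Ontario', 1), (2, 'British Columbia', 2), (3, 'California', 3);
INSERT INTO customer VALUES (1, 1), (2, 1), (3, 2), (4, 3);
""")

# Same shape as the third few-shot example, parameterized on the country name.
QUERY = """
SELECT COUNT(*) FROM customer cu
JOIN address ad ON cu.address_id = ad.address_id
JOIN city ci ON ad.city_id = ci.city_id
JOIN country co ON ci.country_id = co.country_id
WHERE co.country = ?
"""

print(conn.execute(QUERY, ("Canada",)).fetchall())
```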

In [None]:
local_chain = SQLDatabaseChain.from_llm(llm, db, prompt=few_shot_prompt, use_query_checker=True, 
                                        verbose=True, return_sql=False,)

In [None]:
start = time.time()

local_chain.run("How many customers are from country Canada?")

elapsed_time = time.time() - start
print(f"Time taken to construct and run query: {elapsed_time}")


### Rerunning the first prompt to check performance

In [None]:
start = time.time()

local_chain.run("How many customers are from district California?")

elapsed_time = time.time() - start
print(f"Time taken to construct and run query: {elapsed_time}")