In [1]:
import langchain
from langchain.llms import CTransformers

# Generation settings handed to ctransformers; the accepted keys are listed at
# https://github.com/marella/ctransformers#config
config = {
    'max_new_tokens': 256,
    'repetition_penalty': 1.1,
    'temperature': 0,
    'context_length': 10000,
}

# LangChain wrapper around the GGUF model file named below, taken from the
# TheBloke/CodeLlama-7B-Instruct-GGUF repository.
llm = CTransformers(
    model="TheBloke/CodeLlama-7B-Instruct-GGUF",
    model_file="codellama-7b-instruct.Q4_K_M.gguf",
    config=config,
    verbose=True,
)


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

In [2]:
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain

import os

# Switch on LangChain's module-level verbose flag.
langchain.verbose = True

# Connection URI for the Sakila sample database. Export SAKILA_DB_URI to supply
# real credentials without committing them to the notebook; when it is unset the
# original placeholder URI for a local MySQL instance is used.
SAKILA_DB_URI = os.environ.get(
    'SAKILA_DB_URI',
    'mysql://dbuser:dbpwd@localhost:3306/sakila',
)

db = SQLDatabase.from_uri(
    SAKILA_DB_URI,
    #include_tables=['customer', 'address', 'city', 'country'], # include only the tables you want to query. Reduces tokens.
    sample_rows_in_table_info=3,  # number of sample rows listed per table in table_info
)

# Show the schema text (CREATE TABLE statements plus sample rows) for inspection.
print(db.table_info)




CREATE TABLE actor (
	actor_id SMALLINT UNSIGNED NOT NULL AUTO_INCREMENT, 
	first_name VARCHAR(45) NOT NULL, 
	last_name VARCHAR(45) NOT NULL, 
	last_update TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, 
	PRIMARY KEY (actor_id)
)COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB DEFAULT CHARSET=utf8mb4

/*
3 rows from actor table:
actor_id	first_name	last_name	last_update
1	PENELOPE	GUINESS	2006-02-15 04:34:33
2	NICK	WAHLBERG	2006-02-15 04:34:33
3	ED	CHASE	2006-02-15 04:34:33
*/


CREATE TABLE address (
	address_id SMALLINT UNSIGNED NOT NULL AUTO_INCREMENT, 
	address VARCHAR(50) NOT NULL, 
	address2 VARCHAR(50), 
	district VARCHAR(20) NOT NULL, 
	city_id SMALLINT UNSIGNED NOT NULL, 
	postal_code VARCHAR(10), 
	phone VARCHAR(20) NOT NULL, 
	last_update TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, 
	PRIMARY KEY (address_id), 
	CONSTRAINT fk_address_city FOREIGN KEY(city_id) REFERENCES city (city_id) ON DELETE RESTRICT ON UPDATE CASCADE
)COL

  self._metadata.reflect(
  for tbl in self._metadata.sorted_tables


In [3]:
import time

# SQL chain over `db` driven by `llm`, with no custom prompt supplied.
# return_sql=False asks for the final answer rather than the SQL text;
# use_query_checker=True enables the chain's query-checking step.
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, return_sql=False, use_query_checker=True)

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# intervals; time.time() can jump if the system clock is adjusted mid-run.
start = time.perf_counter()

# Keep the chain's return value instead of discarding it so it can be shown below.
answer = db_chain.run("How many customers are from district California?")

elapsed_time = time.perf_counter() - start
print(f"Answer: {answer}")
# NOTE(review): the interval spans the whole run() call, not only query
# construction; the label is kept unchanged to match earlier recorded output.
print(f"Time taken to construct query: {elapsed_time}")



[1m> Entering new SQLDatabaseChain chain...[0m
How many customers are from district California?
SQLQuery:

[1m> Entering new LLMChain chain...[0m
Prompt after formatting:
[32;1m[1;3mYou are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention 


[1m> Finished chain.[0m

[1m> Finished chain.[0m
Time taken to construct query: 381.04306626319885


In [None]:
from langchain.prompts.prompt import PromptTemplate

# Few-shot examples: each entry pairs a natural-language question with the SQL
# that answers it, the raw result, and the phrased answer.
#
# The SQL strings use implicit string concatenation inside parentheses. A
# backslash line continuation *inside* a string literal would splice the next
# line's leading indentation into the query, padding every example with a run
# of spaces that ends up in both the prompt and the embedded text.
examples = [
    {
        "input": "How many customers are from district California?",
        "sql_cmd": (
            "SELECT COUNT(*) FROM customer cu "
            "JOIN address ad ON cu.address_id = ad.address_id "
            "WHERE ad.district = 'California';"
        ),
        "result": "[(9,)]",
        "answer": "There are 9 customers from California",
    },
    {
        "input": "How many customers are from city San Bernardino?",
        "sql_cmd": (
            "SELECT COUNT(*) FROM customer cu "
            "JOIN address ad ON cu.address_id = ad.address_id "
            "JOIN city ci ON ad.city_id = ci.city_id "
            "WHERE ci.city = 'San Bernardino';"
        ),
        "result": "[(1,)]",
        "answer": "There is 1 customer from San Bernardino",
    },
    {
        "input": "How many customers are from country United States?",
        "sql_cmd": (
            "SELECT COUNT(*) FROM customer cu "
            "JOIN address ad ON cu.address_id = ad.address_id "
            "JOIN city ci ON ad.city_id = ci.city_id "
            "JOIN country co ON ci.country_id = co.country_id "
            "WHERE co.country = 'United States';"
        ),
        "result": "[(36,)]",
        "answer": "There are 36 customers from United States",
    },
]

# Template that lays out one example dict as a
# Question / SQLQuery / SQLResult / Answer block.
example_prompt = PromptTemplate(
    template="\nQuestion: {input}\nSQLQuery: {sql_cmd}\nSQLResult: {result}\nAnswer: {answer}",
    input_variables=["input", "sql_cmd", "result", "answer"],
)

#print(example_prompt.format(**examples[2]))

In [None]:
from langchain.prompts import SemanticSimilarityExampleSelector
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma


# Embedding model with the wrapper's default settings (no model name given).
embeddings = HuggingFaceEmbeddings()

# One text per example: all of its field values joined with single spaces.
to_vectorize = [" ".join(fields) for fields in map(dict.values, examples)]

# Index the texts, attaching each source example dict as that text's metadata.
vectorstore = Chroma.from_texts(to_vectorize, embeddings, metadatas=examples)

# Selector backed by the vector store above.
# k=1: presumably one nearest example is returned per question; confirm against the library docs.
example_selector = SemanticSimilarityExampleSelector(k=1, vectorstore=vectorstore)

In [None]:
from langchain.prompts import FewShotPromptTemplate
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mysql_prompt

#print(PROMPT_SUFFIX)

# Few-shot prompt assembled from the imported MySQL prefix, the examples chosen
# by `example_selector` (each rendered through `example_prompt`), and the
# imported PROMPT_SUFFIX.
few_shot_prompt = FewShotPromptTemplate(
    prefix=_mysql_prompt,
    example_selector=example_selector,
    example_prompt=example_prompt,
    suffix=PROMPT_SUFFIX,
    # Placeholders that the prefix and suffix text expect to be filled in.
    input_variables=["input", "table_info", "top_k"],
)

In [None]:
# SQL chain over `db` that is given the few-shot prompt explicitly.
local_chain = SQLDatabaseChain.from_llm(
    llm,
    db,
    prompt=few_shot_prompt,
    verbose=True,
    return_sql=False,
    use_query_checker=True,
)

In [None]:
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# intervals; time.time() can jump if the system clock is adjusted mid-run.
start = time.perf_counter()

# Keep the chain's return value instead of discarding it so it can be shown below.
answer = local_chain.run("How many customers are from country Canada?")

elapsed_time = time.perf_counter() - start
print(f"Answer: {answer}")
# NOTE(review): the interval spans the whole run() call, not only query
# construction; the label is kept unchanged to match earlier recorded output.
print(f"Time taken to construct query: {elapsed_time}")
