In [7]:
from langchain.llms import CTransformers  # to load the llama2 model
from langchain_community.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.prompts import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
import spacy
import json
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.prompts import SemanticSimilarityExampleSelector

# Load spaCy model
nlp = spacy.load("en_core_web_sm")

# MySQL connection URI
username = 'root'
password = 'prabal9869'
host = '127.0.0.1'
dbname = 'arl_bank'  # Database name

# Constructing the MySQL URI
mysql_uri = f"mysql+pymysql://{username}:{password}@{host}/{dbname}"

# Initializing SQLDatabase object for MySQL
db = SQLDatabase.from_uri(mysql_uri, sample_rows_in_table_info=3)
print(db.table_info)

# Extract table and column information from the database object
table_info = db.table_info

# Extract column names from the CREATE TABLE statement
# Find the start and end index of the column names section
start_index = table_info.find('(')
end_index = table_info.rfind(')')

# Extract the column names substring
column_names_str = table_info[start_index + 1:end_index]

# Split the column names by comma and remove any leading/trailing whitespace
column_names_list = [col.strip().split()[0] for col in column_names_str.split(',')]

# Assuming the first column is the primary key and should not be included in the dictionary
column_names = {table_info.split()[2]: column_names_list[1:]}  # Assuming table name is after 'CREATE TABLE' and before '('

print(column_names)



CREATE TABLE transactions (
	`Account_No` VARCHAR(50) NOT NULL, 
	`Transaction_details` TEXT, 
	`Withdrawal_amount` INTEGER, 
	`Deposit_amount` INTEGER, 
	`Balance_amount` INTEGER, 
	`Value_date` DATE, 
	`Date` DATE
)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB

/*
3 rows from transactions table:
Account_No	Transaction_details	Withdrawal_amount	Deposit_amount	Balance_amount	Value_date	Date
409000611074'	TRF FROM  Indiaforensic SERVICES	0	1000000	1000000	2022-10-05	2022-10-05
409000611074'	TRF FROM  Indiaforensic SERVICES	0	1000000	2000000	2022-10-11	2022-10-11
409000611074'	FDRL/INTERNAL FUND TRANSFE	0	500000	2500000	2022-10-24	2022-10-24
*/
{'transactions': ['`Transaction_details`', '`Withdrawal_amount`', '`Deposit_amount`', '`Balance_amount`', '`Value_date`', '`Date`']}


In [8]:
def extract_columns_from_query(query):
    doc = nlp(query)
    columns = []
    for token in doc:
        if token.pos_ in ["NOUN", "PROPN"]:  # Assuming columns are typically nouns or proper nouns
            columns.append(token.text)
    return set(columns)


In [9]:
# Define few-shot examples
few_shots = [
    {
        'Question': "What is my income in my last 3 months. My Account Number is 409000493201?",
        'SQLQuery': """SELECT SUM(Deposit_amount) AS Total_Income FROM transactions WHERE Account_No = "409000493201" AND Value_date >= DATE_SUB(CURRENT_DATE(), INTERVAL 3 MONTH);""",
        'SQLResult': "16725509",
        'Answer': "16725509 is the income of last 3 months."
    },
    {
        'Question': "What is the total expenses of last 8 months for account number 409000611074?",
        'SQLQuery': """SELECT SUM(Withdrawal_amount) AS Total_Expenses FROM transactions WHERE Account_No = "409000611074" AND Value_date >= DATE_SUB(CURRENT_DATE(), INTERVAL 8 MONTH);""",
        'SQLResult': "75124046",
        'Answer': "75124046 is the total expenses of last 8 months."
    },
    {
        'Question': "How much did I save last month as my account number is 409000493201?",
        'SQLQuery': """SELECT (SUM(Deposit_amount) - SUM(Withdrawal_amount)) AS Savings_Last_Month FROM transactions WHERE Account_No = "409000493201" AND YEAR(Value_date) = YEAR(CURRENT_DATE() - INTERVAL 1 MONTH) AND MONTH(Value_date) = MONTH(CURRENT_DATE() - INTERVAL 1 MONTH);""",
        'SQLResult': "-193509",
        'Answer': "You saved -193509 last month."
    },
    {
        'Question': "How many transactions did I make in last 2 week as my account number is 409000493201?",
        'SQLQuery': """SELECT COUNT(*) AS Total_Transactions FROM transactions WHERE Account_No = "409000493201" AND Value_date >= DATE_SUB(CURRENT_DATE(), INTERVAL 2 WEEK);""",
        'SQLResult': "24",
        'Answer': "You made 24 transactions in last 2 week."
    },
    {
        'Question': "What was my total spending last week as my account number is 409000493201?",
        'SQLQuery': """SELECT SUM(Withdrawal_amount) AS Total_Spending FROM transactions WHERE Account_No = "409000493201" AND Value_date >= DATE_SUB(CURRENT_DATE(), INTERVAL 1 WEEK);""",
        'SQLResult': "515559",
        'Answer': "515559 is your total spending last week."
    },
    {
        'Question': "Give me a breakdown of my expenses for each day this month as my account number is 409000493201?",
        'SQLQuery': """SELECT Date, SUM(Withdrawal_amount) AS Total_Expenses FROM transactions WHERE Account_No = "409000493201" AND YEAR(Value_date) = YEAR(CURRENT_DATE()) AND MONTH(Value_date) = MONTH(CURRENT_DATE()) GROUP BY Date;""",
        'SQLResult': """
Date                        Total_Expenses
2024-06-02                   183648
2024-06-03                   80666
2024-06-04                   135504
2024-06-05                   92031
2024-06-06                   118475
2024-06-07                   16440
2024-06-09                   62112
2024-06-10                   10331
""",
        'Answer': "The total expenses are 183648, 80666, 135504, 92031, 118475, 16440, 62112, 10331 for the year 2024 month of 6 and date 02, 03, 04, 05, 06, 07, 09 and 10."
    },
    {
        'Question': "Will it rain today?",
        'SQLQuery': None,
        'SQLResult': None,
        'Answer': "Sorry, I don't have access to the information."
    }
]

embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')

# Convert lists and dictionaries to JSON strings, filtering out None values
to_vectorize = [
    " ".join(
        [json.dumps(value) if isinstance(value, (list, dict)) else (value if value is not None else "") for value in example.values()]
    )
    for example in few_shots
]

# Convert metadata to JSON strings, replacing None with an empty string
metadatas = [
    {k: json.dumps(v) if isinstance(v, (list, dict)) else (v if v is not None else "") for k, v in example.items()}
    for example in few_shots
]

vectorstore = Chroma.from_texts(to_vectorize, embeddings, metadatas=metadatas)

example_selector = SemanticSimilarityExampleSelector(
    vectorstore=vectorstore,
    k=1,
)
example_selector.select_examples({"Question": "Do dogs bark?"})


  warn_deprecated(
  from tqdm.autonotebook import tqdm, trange


[{'Answer': "Sorry, I don't have access to the information.",
  'Question': 'Will it rain today?',
  'SQLQuery': '',
  'SQLResult': ''}]

In [10]:
# Define the custom prompt for SQL database interactions
custom_mysql_prompt = """You are an expert in converting natural language into MySQL queries. Follow these instructions carefully to ensure accuracy:
1. Only use the columns specified in the table information provided below.
2. Before proceeding, check if the columns mentioned in the user's natural language query are found in the database. If not, respond with "Sorry, I don't have access to the information."
3. If the columns mentioned in the user's natural language are not found, send the response "Sorry, no information available" and abort/exit without proceeding further.
4. Use the CURDATE() function to get the current date when the question involves "today," "last month," "last year," or "last week."
5. Ensure that your queries are syntactically correct and optimized for performance.
6. Provide clear, accurate, and concise SQL queries without making assumptions beyond the given data.

Table information:
1. `transactions` with columns: `Account_No`, `Date`, `Transaction_details`, `Value_date`, `Withdrawal_amount`, `Deposit_amount`, `Balance_amount`

Use queries for at most {top_k} results.
Remember to only use the provided table columns and structure your queries to handle the specified requests accurately."""

example_prompt = PromptTemplate(
    input_variables=["Question", "SQLQuery", "SQLResult", "Answer"],
    template="\nQuestion: {Question}\nSQLQuery: {SQLQuery}\nSQLResult: {SQLResult}\nAnswer: {Answer}",
)

suffix_eg = "Question: {input}"

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=custom_mysql_prompt,
    suffix=suffix_eg,
    input_variables=["input", "top_k"],  # These variables are used in the prefix and suffix
)


In [24]:
class CustomSQLDatabaseChain(SQLDatabaseChain):
    def __init__(self, llm, db, prompt, verbose=False):
        super().__init__(llm=llm, db=db, prompt=prompt, verbose=verbose)

    def _call(self, inputs):
        query = inputs['input']
        
        # Extract columns from user query
        columns_in_query = extract_columns_from_query(query)
        
        # Validate columns against the schema
        valid_columns = set()
        for table, cols in self.db.table_info.items():  # Use self.db.table_info
            valid_columns.update([col['name'] for col in cols['columns']])
        
        if not columns_in_query.issubset(valid_columns):
            return {"output": "Sorry, information is not available"}
        
        # Proceed with the original SQLDatabaseChain
        return super()._call(inputs)


In [25]:
# Initialize LLM
llm = CTransformers(model="model/llama-2-7b-chat.ggmlv3.q4_0.bin",
                    model_type="llama", config={'max_new_tokens': 3800, 'temperature': 0.5, 'context_length': 3600})

# Initialize the custom chain
custom_chain = CustomSQLDatabaseChain(llm=llm, db=db, prompt=few_shot_prompt, verbose=True)
# Example usage
result = custom_chain({"input": "Will it rain today?"})
print(result['output'])  # Output should be "Sorry, information is not available"

result = custom_chain({"input": "What is my income in my last 3 months. My Account Number is 409000493201?"})
print(result['output'])  # This will process the query if columns are valid




KeyError: 'database'

In [6]:
print(table_info)


CREATE TABLE transactions (
	`Account_No` VARCHAR(50) NOT NULL, 
	`Transaction_details` TEXT, 
	`Withdrawal_amount` INTEGER, 
	`Deposit_amount` INTEGER, 
	`Balance_amount` INTEGER, 
	`Value_date` DATE, 
	`Date` DATE
)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB

/*
3 rows from transactions table:
Account_No	Transaction_details	Withdrawal_amount	Deposit_amount	Balance_amount	Value_date	Date
409000611074'	TRF FROM  Indiaforensic SERVICES	0	1000000	1000000	2022-10-05	2022-10-05
409000611074'	TRF FROM  Indiaforensic SERVICES	0	1000000	2000000	2022-10-11	2022-10-11
409000611074'	FDRL/INTERNAL FUND TRANSFE	0	500000	2500000	2022-10-24	2022-10-24
*/
