In [81]:
import getpass
import os
import re
from urllib.parse import quote_plus

from dotenv import load_dotenv
from langchain.llms import HuggingFaceHub
from langchain_community.utilities import SQLDatabase
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI

In [82]:
# Load variables from a local .env file into os.environ (python-dotenv).
# Presumably this supplies the API key ChatGroq needs — confirm the .env contents.
load_dotenv()

True

In [83]:
# Module-level chat model; temperature=0 for the most repeatable output.
# NOTE(review): get_sql_chain and get_response construct their own ChatGroq
# instances by default rather than using this object.
llm = ChatGroq(temperature=0, model="mixtral-8x7b-32768")

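## Connecting to the database
- When a value placed in the connection URI (for example the password) contains special characters such as `@`, it must be URL-encoded; otherwise the URI is parsed incorrectly.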

In [84]:
def init_database(user: str, password: str, host: str, port: str, database: str) -> SQLDatabase:
  """Connect to a MySQL database and wrap it in a LangChain ``SQLDatabase``.

  Args:
    user: MySQL user name.
    password: MySQL password; may contain URI-reserved characters such as '@'.
    host: Server host name or IP address.
    port: Server port, e.g. '3306' (it is formatted into the URI as text).
    database: Name of the schema to connect to.

  Returns:
    A ``SQLDatabase`` built from a ``mysql+mysqlconnector://`` URI.
  """
  # The user name and password both sit in the URI's userinfo section, where
  # '@', ':' and '/' act as delimiters, so both must be percent-encoded or the
  # URI is split in the wrong place.
  user = quote_plus(user)
  password = quote_plus(password)
  db_uri = f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}"
  return SQLDatabase.from_uri(db_uri)

## Creating the SQL-generation prompt template

In [85]:
def get_sql_chain(db, llm=None):
    """Build a runnable that turns a user question into a SQL query string.

    The chain fills ``{schema}`` from ``db.get_table_info()`` on every
    invocation, so the input dict only needs ``question`` and ``chat_history``.

    Args:
        db: A LangChain ``SQLDatabase`` whose table info is shown to the model.
        llm: Chat model used to write the SQL. When omitted, the same
            ``ChatGroq(model="mixtral-8x7b-32768", temperature=0)`` as before is
            created, so existing callers are unaffected. Pass another model
            (e.g. ``ChatOpenAI()``) to switch providers without editing this
            function.

    Returns:
        A runnable whose ``invoke`` returns the model's text output as a string.
        The prompt asks for a bare SQL query, but nothing here enforces that.
    """
    template = """
    You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
    Based on the table schema below, write a SQL query that would answer the user's question. Take the conversation history into account.
    
    <SCHEMA>{schema}</SCHEMA>
    
    Conversation History: {chat_history}
    
    Write only the SQL query and nothing else. Do not wrap the SQL query in any other text, not even backticks.
    
    For example:
    Question: which 3 artists have the most tracks?
    SQL Query: SELECT ArtistId, COUNT(*) as track_count FROM Track GROUP BY ArtistId ORDER BY track_count DESC LIMIT 3;
    Question: Name 10 artists
    SQL Query: SELECT Name FROM Artist LIMIT 10;
    
    Your turn:
    
    Question: {question}
    SQL Query:
    """

    prompt = ChatPromptTemplate.from_template(template)

    if llm is None:
        llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

    def get_schema(_):
        # The input dict is ignored; the schema always comes from `db`.
        return db.get_table_info()

    return (
        RunnablePassthrough.assign(schema=get_schema)
        | prompt
        | llm
        | StrOutputParser()
    )

## Validating generated queries

In [86]:
def validate_sql_query(query: str) -> bool:
    """Return True if ``query`` looks like a single read-only SELECT statement.

    This is a cheap guard to apply before executing LLM-generated SQL, so that
    DDL/DML (DROP, DELETE, UPDATE, ...) is not run. It is a lexical check, not
    a SQL parser. It is deliberately conservative and also rejects some
    harmless queries, e.g. ones with a ';' inside a string literal, or CTEs
    starting with WITH.

    Args:
        query: Candidate SQL text. Empty/whitespace-only strings and None are
            rejected.

    Returns:
        True only if the text starts with the keyword SELECT and contains no
        statement separator other than trailing semicolons.
    """
    if not query:
        return False
    # Drop surrounding whitespace and any trailing ';' terminators.
    statement = query.strip().rstrip(";").strip()
    # Any remaining ';' means stacked statements, e.g. "SELECT 1; DROP TABLE t",
    # which a prefix-only check would wrongly accept.
    if ";" in statement:
        return False
    # Require SELECT as a whole keyword, so "selection ..." does not pass.
    return re.match(r"select\b", statement, flags=re.IGNORECASE) is not None


## Setting up the LLM response chain

In [87]:
def get_response(user_query: str, db: SQLDatabase, chat_history: list):
  """Answer ``user_query`` in natural language using the database.

  Two LLM calls are chained:
    1. ``get_sql_chain(db)`` turns the question (plus history) into SQL.
    2. That SQL is executed against ``db``. The result, schema, query and
       question are fed to a second prompt that writes the final answer.

  Generated SQL is executed only if ``validate_sql_query`` accepts it. An LLM
  that emits DROP/DELETE/UPDATE therefore cannot modify the database. In that
  case the answering model is told the query was rejected instead.

  Args:
    user_query: The user's question.
    db: Database to query.
    chat_history: Prior messages, interpolated into both prompts.

  Returns:
    The model's natural-language answer as a string.
  """
  sql_chain = get_sql_chain(db)

  template = """
    You are a data analyst at a company. You are interacting with a user who is asking you questions about the company's database.
    Based on the table schema below, question, sql query, and sql response, write a natural language response.
    <SCHEMA>{schema}</SCHEMA>

    Conversation History: {chat_history}
    SQL Query: <SQL>{query}</SQL>
    User question: {question}
    SQL Response: {response}"""

  prompt = ChatPromptTemplate.from_template(template)
  llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

  def run_query(inputs):
    # Never execute model output blindly: only single SELECT statements run.
    # NOTE(review): large result sets are passed to the prompt untruncated and
    # may exceed the model's context window.
    query = inputs["query"]
    if not validate_sql_query(query):
      return "Query rejected: only read-only SELECT statements are executed."
    return db.run(query)

  chain = (
    RunnablePassthrough.assign(query=sql_chain).assign(
      schema=lambda _: db.get_table_info(),
      response=run_query,
    )
    | prompt
    | llm
    | StrOutputParser()
  )

  return chain.invoke({
    "question": user_query,
    "chat_history": chat_history,
  })
    

## Integrating

In [88]:
# Database connection settings. Each non-secret value can be overridden with an
# environment variable (load_dotenv() above also reads them from a .env file).
# Otherwise it falls back to the local defaults.
user = os.environ.get("DB_USER", "root")
host = os.environ.get("DB_HOST", "localhost")
port = os.environ.get("DB_PORT", "3306")
name = os.environ.get("DB_NAME", "clv")
# Never hard-code the password in a notebook: it ends up in version control and
# in every shared copy. Read it from the environment, or prompt for it.
# NOTE: a password was previously committed in this cell; rotate it.
password = os.environ.get("DB_PASSWORD") or getpass.getpass("MySQL password: ")


In [89]:

# Open the MySQL connection with the settings from the previous cell.
db = init_database(user=user, password=password, host=host, port=port, database=name)
print("Connected to database!")

Connected to database!


In [90]:
# Seed the conversation with the assistant's greeting.
greeting = AIMessage(content="Hello! I'm a SQL assistant. Ask me anything about your database.")
chat_history = [greeting]

# Echo the history so far, labelling each message with its sender.
speaker_labels = ((AIMessage, "AI"), (HumanMessage, "Human"))
for msg in chat_history:
    for msg_type, label in speaker_labels:
        if isinstance(msg, msg_type):
            print(f"{label}: {msg.content}")
            break

# Interactive chat loop. Type 'exit' or 'quit' to stop; in a notebook the
# kernel is otherwise blocked on input() until interrupted.
while True:
    user_query = input("Type a message (or 'exit' to quit)...").strip()
    if user_query.lower() in {"exit", "quit"}:
        break
    if not user_query:
        # Nothing typed: don't send an empty question to the LLM/database.
        continue

    chat_history.append(HumanMessage(content=user_query))
    print(f"Human: {user_query}")

    try:
        response = get_response(user_query, db, chat_history)
    except Exception as exc:
        # A failed LLM call or a bad generated query should not end the whole
        # session; report the error and let the user try again.
        print(f"AI: Sorry, something went wrong: {exc}")
        continue

    print(f"AI: {response}")
    chat_history.append(AIMessage(content=response))


AI: Hello! I'm a SQL assistant. Ask me anything about your database.
Human: how many states are there
AI: Hello! Based on the query results, there are 4 distinct states represented in the data table. These states are Washington, Arizona, Nevada, and California.
Human: how many rows are in the dataset
AI: Based on the SQL response, there are 9134 rows in the dataset.
Human: name all rows in the dataaset
