<a href="https://colab.research.google.com/github/ramahasiba/NLP/blob/langGraph/Build_a_Question_Answering_System_Over_SQL_Data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# [Build a Question Answering System Over SQL Data](https://python.langchain.com/docs/tutorials/sql_qa/)

In [None]:
!apt-get update && apt-get install -y sqlite3 -q

In [None]:
!curl -s https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql | sqlite3 Chinook.db

In [None]:
%%capture --no-stderr
%pip install --upgrade --quiet langchain-community langgraph

In [None]:
!pip install dotenv -q
from dotenv import load_dotenv
try:
  load_dotenv('.env')
except ImportError:
  print('No .env file found')

In [None]:
import getpass
import os

os.environ["LANGSMITH_TRACING"] = "true"
# os.environ only accepts str values, so the key cannot be copied over from
# os.environ.get() (None when unset); prompt for it when it is missing.
if not os.environ.get("LANGSMITH_API_KEY"):
  os.environ["LANGSMITH_API_KEY"] = getpass.getpass("LangSmith API key: ")

In [None]:
!pip install -qU "langchain[groq]"
os.environ["GROQ_API_KEY"]=os.environ.get("GROQ_API_KEY")

model_name = "llama3-70b-8192"

from langchain.chat_models import init_chat_model
llm=init_chat_model(model_name, model_provider="groq")

In [None]:
!pip install -qU langchain-huggingface

from langchain_huggingface import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

In [None]:
!pip install -qU langchain-chroma

from langchain_chroma import Chroma

vector_store = Chroma(
    collection_name="example_collection",
    embedding_function=embeddings,
    persist_directory="./chroma_langchain_db",  # Where to save data locally, remove if not necessary
)

In [None]:
from langchain_community.utilities import SQLDatabase

# Wrap the local Chinook SQLite file, then sanity-check the connection:
# dialect, visible tables, and a small sample of the Artist table.
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
for fact in (db.dialect, db.get_usable_table_names()):
  print(fact)
db.run("SELECT * FROM Artist LIMIT 10;")

In [None]:
from typing_extensions import TypedDict

class State(TypedDict):
    """Data passed between the nodes of the SQL question-answering graph."""

    question: str  # natural-language question from the user
    query: str     # SQL produced by write_query
    result: str    # output of executing `query`
    answer: str    # final natural-language answer

In [None]:
from langchain_core.prompts import ChatPromptTemplate

# System prompt for the query-writing step. The {dialect}, {top_k},
# {table_info} and {input} placeholders are filled in by write_query.
system_message = """
Given an input question, create a syntactically correct {dialect} query to
run to help find the answer. Unless the user specifies in their question a
specific number of examples they wish to obtain, always limit your query to
at most {top_k} results. You can order the results by a relevant column to
return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for the
few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema
description. Be careful to not query for columns that do not exist. Also,
pay attention to which column is in which table.

Only use the following tables:
{table_info}
"""

user_prompt = "Question: {input}"

query_prompt_template = ChatPromptTemplate(
    [("system", system_message), ("user", user_prompt)]
)

for message in query_prompt_template.messages:
  message.pretty_print()

In [None]:
from typing_extensions import Annotated, TypedDict
class QueryOutput(TypedDict):
  """Generated SQL query."""
  query: Annotated[str, ..., "Syntactically valid SQL query."]

def write_query(state: State):
  """Generate SQL query to fetch information."""
  prompt = query_prompt_template.invoke(
      {
          "dialect": db.dialect,
          "top_k": 10,
          "table_info": db.get_table_info(),
          "input": state["question"],
      }
  )
  structured_llm = llm.with_structured_output(QueryOutput)
  result = structured_llm.invoke(prompt)
  return {"query": result["query"]}

In [None]:
write_query({"question": "How many Employees are there?"})

In [None]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool

def execute_query(state: State):
  """Execute SQL query."""
  execute_query_tool = QuerySQLDatabaseTool(db=db)
  return {"result": execute_query_tool.invoke(state["query"])}

In [None]:
execute_query({"query": "SELECT COUNT(EmployeeId) AS EmployeeCount FROM Employee;"})

In [None]:
def generate_answer(state: State):
  """Answer the user's question from the generated SQL and its result.

  Builds a plain-text prompt containing ``state["question"]``,
  ``state["query"]`` and ``state["result"]``, sends it to ``llm`` and returns
  the reply text as ``{"answer": ...}``.
  """
  # Adjacent string literals are concatenated verbatim, so each fragment must
  # carry its own separating whitespace.
  prompt = (
      "Given the following user question, corresponding SQL query, "
      "and SQL result, answer the user question.\n\n"
      f"Question: {state['question']}\n"
      f"SQL Query: {state['query']}\n"
      f"SQL Result: {state['result']}\n"
  )
  response = llm.invoke(prompt)
  return {"answer": response.content}

In [None]:
from langgraph.graph import START, StateGraph

graph_builder = StateGraph(State).add_sequence(
    [write_query, execute_query, generate_answer]
)
graph_builder.add_edge(START, "write_query")
graph = graph_builder.compile()

In [None]:
graph

In [None]:
for step in graph.stream(
    {"question": "How many employees are there?"}, stream_mode="updates"
):
  print(step)

In [None]:
from langgraph.checkpoint.memory import MemorySaver

memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory, interrupt_before=["execute_query"])

config = {"configurable": {"thread_id": "1"}}

In [None]:
graph

In [None]:
# Run up to the interrupt before execute_query.
for step in graph.stream(
    {"question": "How many employees are there?"},
    config,
    stream_mode="updates",
):
  print(step)

try:
  user_approval = input("Do you want to go to execute query? (yes/no):")
except Exception:
  # input() can fail in non-interactive runs; treat that as a refusal.
  user_approval = "no"

if user_approval.strip().lower() in ("yes", "y"):
  # If approved, continue the paused run by streaming None on the same thread.
  for step in graph.stream(None, config, stream_mode="updates"):
    print(step)
else:
  print("Operation cancelled by user.")

In [None]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

# Ready-made SQL database tools for the agent, bound to `db` and `llm`.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = list(toolkit.get_tools())
tools

In [None]:
# Agent system prompt. The dialect and row limit are substituted directly, so
# the resulting string has no remaining placeholders.
AGENT_SQL_DIALECT = "SQLite"
AGENT_TOP_K = 5

system_message = f"""
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {AGENT_SQL_DIALECT} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {AGENT_TOP_K} results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.

Then you should query the schema of the most relevant tables.
"""

In [None]:
from langchain_core.messages import HumanMessage
from langgraph.prebuilt import create_react_agent

agent_executor = create_react_agent(llm, tools, prompt=system_message)

In [None]:
question = "Which country's customers spent the most?"

for step in agent_executor.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
  step["messages"][-1].pretty_print()

In [None]:
for step in agent_executor.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
  print(f"{step}\n\n----------------\n")

In [None]:
import ast
import re

def query_as_list(db, query):
  """Return the distinct text values produced by ``query``, with numbers removed.

  ``db.run`` is expected to return the ``str()`` of a list of row tuples (as
  the SQLDatabase calls elsewhere in this notebook do). An empty string is
  treated as "no rows". Every truthy cell is converted to text, standalone
  digit runs are stripped (e.g. "Vol. 2" -> "Vol."), and values that end up
  empty are dropped.

  Args:
    db: Object with a ``run(query) -> str`` method, e.g. a ``SQLDatabase``.
    query: SQL statement to execute.

  Returns:
    Sorted list of unique, non-empty strings. Sorting keeps the result, and
    anything built from it, reproducible across runs.
  """
  raw_result = db.run(query)
  if not raw_result:
    # Nothing to parse; ast.literal_eval("") would raise.
    return []
  rows = ast.literal_eval(raw_result)
  values = {
      # str() so numeric cells do not make re.sub raise TypeError.
      re.sub(r"\b\d+\b", "", str(cell)).strip()
      for row in rows
      for cell in row
      if cell
  }
  # Values made only of digits collapse to "" and are useless as proper nouns.
  values.discard("")
  return sorted(values)

artists = query_as_list(db, "SELECT Name FROM Artist")
albums = query_as_list(db, "SELECT Title FROM Album")
albums[:5]

In [None]:
!pip install -qU langchain-chroma

In [None]:
from langchain_chroma import Chroma

vector_store = Chroma(
    collection_name="SQL_Agent",
    embedding_function=embeddings,
    persist_directory="./db",  # Where to save data locally, remove if not necessary
)

In [None]:
from langchain.agents.agent_toolkits import create_retriever_tool

_ = vector_store.add_texts(artists + albums)
retriever = vector_store.as_retriever(search_kwargs={"k": 5})
description = (
    "Use to look up values to filter on. Input is an approximate spelling "
    "of the proper noun, output is valid proper nouns. Use the noun most "
    "similar to the search."
)

retriever_tool = create_retriever_tool(
    retriever,
    name="search_proper_nouns",
    description=description,
)

In [None]:
print(retriever_tool.invoke("Alice Chains"))

In [None]:
# Add to system message
suffix = (
    "If you need to filter on a proper noun like a Name, you must ALWAYS first look up "
    "the filter value using the 'search_proper_nouns' tool! Do not try to "
    "guess at the proper name - use this function to find similar ones."
)

system = f"{system_message}\n\n{suffix}"

# Rebuild the list instead of appending in place, so re-running this cell (or
# re-creating retriever_tool) never registers two tools with the same name.
tools = [tool for tool in tools if tool.name != retriever_tool.name]
tools.append(retriever_tool)

agent = create_react_agent(llm, tools, prompt=system)

In [None]:
question = "How many albums does Aisha Duo in chain have?"

for step in agent.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()

In [None]:
print(retriever_tool.invoke("Alis"))

In [None]:
# Preview up to 5 rows of every usable table; failures are printed, not raised.
tables = db.get_usable_table_names()

for table_name in tables:
    print(f"\n--- Sample rows from table: {table_name} ---")
    query = f"SELECT * FROM {table_name} LIMIT 5;"
    try:
        rows = db.run(query)
        print(rows)
    except Exception as e:
        print(f"Could not retrieve data from {table_name}: {e}")

In [None]:
print(db.run("SELECT * FROM Artist;"))

In [None]:
question = "How many artists are there?"
artists_number_response = agent.invoke(
    {"messages": [{"role": "user", "content": question}]}
    )

In [None]:
artists_number_response["messages"][-1]

In [None]:
agent