<a href="https://colab.research.google.com/github/shahnbej/Langchain/blob/main/Fully_Local_SQL_Agent_with_Llama_3_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

1. Configure Langsmith
2. Import LLM
3. Import Data
4. Dynamic few-shot prompt
5. Custom SQL Tools
6. ReAct Agent Executor
7. Persistent Memory
8. Showcase in Gradio UI

Langsmith Configuration

In [1]:
%%capture
!pip install python-dotenv langchain_community langchain_ollama langchain_huggingface

In [2]:

from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv(), override=True)

False

In [3]:
import os
langsmith_api_key = os.environ.get("LANGSMITH_API_KEY")


os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "Local SQL Agent"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_API_KEY"] = langsmith_api_key

LLM

In [4]:
from langchain_ollama import ChatOllama

llm = ChatOllama(model="llama3.1:8b-instruct-q4_0")  #"llama3.1"

Database

In [5]:
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///chinook.db", sample_rows_in_table_info = 3)

In [6]:
print(db.table_info)




Few Shot Examples

In [7]:
examples = [
    {   "input": "List all artists.",
        "query": "SELECT * FROM Artist;"},
    {
        "input": "Find all albums for the artist 'AC/DC'.",
        "query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');",
    },
    {
        "input": "List all tracks in the 'Rock' genre.",
        "query": "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');",
    },
    {
        "input": "Find the total duration of all tracks.",
        "query": "SELECT SUM(Milliseconds) FROM Track;",
    },
    {
        "input": "List all customers from Canada.",
        "query": "SELECT * FROM Customer WHERE Country = 'Canada';",
    },
    {
        "input": "How many tracks are there in the album with ID 5?",
        "query": "SELECT COUNT(*) FROM Track WHERE AlbumId = 5;",
    },
    {
        "input": "Find the total number of Albums.",
        "query": "SELECT COUNT(DISTINT(AlbumId)) FROM Invoice;",
    },
    {
        "input": "List all tracks that are longer than 5 minutes.",
        "query": "SELECT * FROM Track WHERE Milliseconds > 300000;",
    },
    {
        "input": "Who are the top 5 customers by total purchase?",
        "query": "SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;",
    },
    {
        "input": "How many employees are there",
        "query": 'SELECT COUNT(*) FROM "Employee"',
    },
]
print(len(examples))

10


Dynamic Example Selector

In [8]:
from langchain_huggingface import HuggingFaceEmbeddings

In [9]:
embeddings = HuggingFaceEmbeddings(model_name = 'sentence-transformers/all-MiniLM-L6-v2')

  from tqdm.autonotebook import tqdm, trange


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [11]:
!pip install faiss-gpu

Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2


In [12]:
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    embeddings,
    FAISS,
    k=2,
    input_keys=["input"],
    )

example_selector.vectorstore.search("How many arists are there?", search_type = "mmr")

[Document(metadata={'input': 'How many tracks are there in the album with ID 5?', 'query': 'SELECT COUNT(*) FROM Track WHERE AlbumId = 5;'}, page_content='How many tracks are there in the album with ID 5?'),
 Document(metadata={'input': 'How many employees are there', 'query': 'SELECT COUNT(*) FROM "Employee"'}, page_content='How many employees are there'),
 Document(metadata={'input': 'Who are the top 5 customers by total purchase?', 'query': 'SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;'}, page_content='Who are the top 5 customers by total purchase?'),
 Document(metadata={'input': 'Find the total duration of all tracks.', 'query': 'SELECT SUM(Milliseconds) FROM Track;'}, page_content='Find the total duration of all tracks.')]

Prompt

In [13]:
system_prefix = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct sqlite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.

You have access to the following tools for interacting with the database:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of {tool_names}
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the answer.
If you see you are repeating yourself, just provide final answer and exit.

Here are some examples of user inputs and their corresponding SQL queries:"""

In [14]:
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

dynamic_few_shot_prompt = FewShotPromptTemplate(
    example_selector = example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
    input_variables=["input"],
    prefix=system_prefix,
    suffix=""
)

In [15]:
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate

full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=dynamic_few_shot_prompt),
        ("human", "{input}"),
        ("system", "{agent_scratchpad}"),
    ]
)

Custom Tools

In [16]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool, InfoSQLDatabaseTool, ListSQLDatabaseTool, QuerySQLCheckerTool

tools = [QuerySQLDataBaseTool(db = db), InfoSQLDatabaseTool(db = db), ListSQLDatabaseTool(db = db), QuerySQLCheckerTool(db = db, llm = llm)]
print(QuerySQLDataBaseTool(db = db).description)

prompt_val = full_prompt.invoke(
    {
        "input": "How many arists are there?",
        "tool_names" : [tool.name for tool in tools],
        "tools" : [tool.name + " - " + tool.description.strip() for tool in tools],
        "agent_scratchpad": [],
    }
)

print(prompt_val.to_string())


    Execute a SQL query against the database and get back the result..
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    
System: You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct sqlite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.

You have access to the following tools for interacting with the database:

['sql_db_query - Execute a SQL query against the database and get back the result..\n    If the query is not correct, an error message will be returned.\n    

Agent Executor

In [17]:
from langchain.agents import AgentExecutor, create_react_agent
agent = create_react_agent(llm, tools, full_prompt)

agent_executor = AgentExecutor(agent=agent, tools=tools, handle_parsing_errors=True)

History Management

In [18]:
last_k_messages = 4


from langchain_community.chat_message_histories import SQLChatMessageHistory

def get_session_history(session_id):
    chat_message_history = SQLChatMessageHistory(
    session_id=session_id, connection = "sqlite:///memory.db", table_name = "local_table"
    )

    messages = chat_message_history.get_messages()
    chat_message_history.clear()

    for message in messages[-last_k_messages:]:
        chat_message_history.add_message(message)

    print("chat_message_history ", chat_message_history)
    return chat_message_history


from langchain_core.runnables.history import RunnableWithMessageHistory

agent_with_chat_history = RunnableWithMessageHistory(
    agent_executor,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
)

Gradio UI

In [20]:
!pip install gradio

Collecting gradio
  Downloading gradio-4.40.0-py3-none-any.whl.metadata (15 kB)
Collecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)
Collecting fastapi (from gradio)
  Downloading fastapi-0.111.1-py3-none-any.whl.metadata (26 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.4.0-py3-none-any.whl.metadata (2.9 kB)
Collecting gradio-client==1.2.0 (from gradio)
  Downloading gradio_client-1.2.0-py3-none-any.whl.metadata (7.1 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart>=0.0.9 (from gradio)
  Downloading python_multipart-0.0.9-py3-none-any.whl.metadata (2.5 kB)
Collecting ruff>=0.2.2 (from gradio)
  Downloading ruff-0.5.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (24 kB)
Collecting semantic-version~=2.0 (from gradio)
  Downloading semantic_version-2.10.0-py2.py3-none-any.whl.metadata (9.7 kB)
Collecting tomlkit==0.12.0 (

In [1]:
import gradio as gr
import uuid


with gr.Blocks() as demo:

    state = gr.State("")
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])


    def respond(message, chatbot_history, session_id):
        if not chatbot_history:
            session_id = uuid.uuid4().hex

        print("Session ID: ", session_id)

        response = agent_with_chat_history.invoke(
                                        {"input": message},
                                        {"configurable": {"session_id": session_id}},
                                        )

        chatbot_history.append((message, response['output']))
        return "", chatbot_history, session_id

    msg.submit(respond, [msg, chatbot, state], [msg, chatbot, state])

#demo.launch(debug=True)
demo.launch()


ModuleNotFoundError: No module named 'gradio'