# SQL agent

## Setup

### Imports

In [1]:
import os
import pickle

import db_connect
from test_data import TestData

from langsmith import Client

from langchain_core.language_models import BaseLanguageModel

from langchain.prompts import PromptTemplate

from langchain.agents.agent_types import AgentType
from langchain.agents.agent_toolkits.sql import base
from langchain.agents.agent import AgentExecutor
from langchain.agents import create_react_agent
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

from langchain_community.llms.ollama import Ollama
from langchain_community.embeddings.ollama import OllamaEmbeddings
from langchain_experimental.llms.ollama_functions import OllamaFunctions

from prompt_generator import PromptGenerator
from examples import FEW_SHOT_EXAMPLES

from langchain_openai.chat_models import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings

### LangSmith

In [2]:
# LangSmith: group every traced run under the "text2sql" project.
os.environ["LANGCHAIN_PROJECT"] = "text2sql"
# Setting the project name alone does not turn tracing on; enable it unless the
# environment already sets LANGCHAIN_TRACING_V2 explicitly (e.g. to "false").
os.environ.setdefault("LANGCHAIN_TRACING_V2", "true")
client = Client()

### Load models

In [3]:
# OpenAI-compatible endpoint used for the ChatGPT models. No api_key is passed,
# so langchain_openai falls back to the OPENAI_API_KEY environment variable.
VSEGPT_BASE_URL = "https://api.vsegpt.ru/v1/"

# Local embeddings (used by the few-shot example selector).
llama3_embeddings = OllamaEmbeddings(model="llama3:8b", temperature=0)

# Plain completion models.
llama3 = Ollama(model="llama3:instruct", temperature=0)
llama3_q8 = Ollama(model="llama3:8b-instruct-q8_0", temperature=0)

# Tool-calling models handed to the agents. temperature=0 on every model under
# test so the comparison is (near-)deterministic, matching the Ollama models above.
llama3__with_functions = OllamaFunctions(model="llama3:instruct", format="json", temperature=0)
llama3_q8__with_functions = OllamaFunctions(model="llama3:8b-instruct-q8_0", format="json", temperature=0)

chatgpt__with_functions = ChatOpenAI(
    base_url=VSEGPT_BASE_URL,
    model="openai/gpt-3.5-turbo-0125",
    temperature=0,
)
# FIXME(review): "gpt-3.5-turbo-0125" is a chat-completion model, not an embedding
# model, so embedding calls through this object are expected to fail. Replace it
# with the provider's embedding model id (e.g. a text-embedding-3-* variant);
# confirm the exact name in the vsegpt model list.
chatgpt_embeddings = OpenAIEmbeddings(base_url=VSEGPT_BASE_URL, model="openai/gpt-3.5-turbo-0125")

### Connect to the DB with a read-only role

In [4]:
# Database handle shared by the agents: get_agent passes it as `db=` to
# create_sql_agent, and the next cell queries it with `.run`.
# NOTE(review): the section heading says this uses a read-only role; that is
# configured inside db_connect.get_db (not visible here). Confirm.
db = db_connect.get_db()

#### Check connection

In [5]:
# Smoke test: read the passenger table to prove the connection and permissions work.
passenger_check_sql = "select * from passenger"
db.run(passenger_check_sql)

"[(1, 'John'), (2, 'James'), (3, 'Poul'), (4, 'Christofer'), (5, 'Superman'), (6, 'Donald'), (7, 'Douglas'), (8, 'Dwight'), (9, 'Earl'), (10, 'Edgar'), (11, 'Edmund'), (12, 'Edwin'), (13, 'Elliot'), (14, 'Eric'), (15, 'Ernest'), (16, 'Ethan'), (17, 'Ezekiel'), (18, 'Felix'), (19, 'Franklin'), (20, 'Frederick'), (21, 'Gabriel'), (22, 'Joseph'), (23, 'Joshua'), (24, 'Julian'), (25, 'Alice'), (26, 'Bob'), (27, 'Charlie'), (28, 'David'), (29, 'Emily'), (30, 'Frank'), (31, 'George'), (32, 'Helen'), (33, 'Irene'), (34, 'Jack'), (35, 'Kate'), (36, 'Leo'), (37, 'Mary'), (38, 'Nancy'), (39, 'Oliver'), (40, 'Paul'), (41, 'Qiana'), (42, 'Robert'), (43, 'Samantha'), (44, 'Thomas'), (45, 'Victoria')]"

-----

## Create agent

### Prompt and error handler

In [6]:
# Few-shot prompt source for the agents: picks k=3 examples out of
# FEW_SHOT_EXAMPLES using the llama3 embeddings (presumably by similarity to
# the incoming question; confirm in prompt_generator.py).
# set_example_selector appears to return the generator itself, since
# get_agent later calls .get_prompt() on the result.
prompt_generator = (
    PromptGenerator()
    .set_example_selector(
        examples=FEW_SHOT_EXAMPLES,
        embedding_llm=llama3_embeddings,
        k=3,
    )
)

### Helper that builds an SQL agent for a given LLM

In [7]:
def get_agent(
    llm: BaseLanguageModel,
    agent_type: str = "openai-tools",
    **agent_kwargs,
) -> AgentExecutor:
    """Create an SQL agent over the notebook's `db` connection for `llm`.

    Args:
        llm: Language model that drives the agent.
        agent_type: Agent flavour forwarded to `create_sql_agent`. Defaults to
            "openai-tools" (the previously hard-coded value). An `AgentType`
            member is also accepted by `create_sql_agent`. It is overridable
            because not every model handles the tool-calling agent equally
            (the llama3 runs below return raw tool-call JSON, not an answer).
        **agent_kwargs: Extra keyword arguments forwarded unchanged to
            `create_sql_agent` (e.g. `verbose=True`, `max_iterations=...`).

    Returns:
        The `AgentExecutor` built by `create_sql_agent`.
    """
    return base.create_sql_agent(
        llm=llm,
        db=db,
        agent_type=agent_type,
        # Few-shot prompt produced by the notebook-level `prompt_generator`.
        prompt=prompt_generator.get_prompt(),
        **agent_kwargs,
    )

-----
## Test the agent with different models and save the results

In [11]:
# Reference answer for test question 0, to compare with each agent's output.
expected_answer = TestData.ANSWER[0]
print(expected_answer)

Here are the names you asked for: 'Alice', 'Bob', 'Charlie', 'Christofer', 'David'


In [8]:
# Test question 0 answered by the ChatGPT-backed agent.
# "dialect" and "top_k" presumably fill variables of the generated prompt. Confirm in PromptGenerator.
chatgpt_agent = get_agent(chatgpt__with_functions)
res = chatgpt_agent.invoke(
    {
        "dialect": "PostgreSQL",
        "input": TestData.QUESTIONS[0],
        "top_k": "3",
    }
)

In [10]:
# Final answer text from the ChatGPT-backed agent run.
chatgpt_output = res["output"]
print(chatgpt_output)

The first 5 names of the passengers, sorted alphabetically, are:
1. Alice
2. Bob
3. Charlie
4. Christofer
5. David


In [12]:
# Test question 0 answered by the llama3:instruct tool-calling agent.
llama3_agent = get_agent(llama3__with_functions)
res = llama3_agent.invoke(
    {
        "dialect": "PostgreSQL",
        "input": TestData.QUESTIONS[0],
        "top_k": "3",
    }
)

In [13]:
# Final answer text from the llama3:instruct agent run.
llama3_output = res["output"]
print(llama3_output)

{ "query": "SELECT  \"name\" FROM public.\"passenger\" ORDER BY  \"name\" LIMIT 5;" }

 

 
 

 
 

 
 

 
 

 

 
 

 
 

 

 

 
 

 
 

 

 

 

 

 

 

 

 

 

 
 




In [14]:
# Test question 0 answered by the q8-quantised llama3 tool-calling agent.
llama3_q8_agent = get_agent(llama3_q8__with_functions)
res = llama3_q8_agent.invoke(
    {
        "dialect": "PostgreSQL",
        "input": TestData.QUESTIONS[0],
        "top_k": "3",
    }
)

In [15]:
# Final answer text from the q8-quantised llama3 agent run.
llama3_q8_output = res["output"]
print(llama3_q8_output)

{"Query": "SELECT passenger.name FROM public.passenger ORDER BY passenger.name LIMIT 5;"}

  
 
