# SQL agent

## Setup

### Imports

In [19]:
import os
import pickle

import db_connect
from test_data import TestData

from langsmith import Client

from langchain_core.language_models import BaseLanguageModel

from langchain.prompts import PromptTemplate

from langchain.agents.agent_types import AgentType
from langchain.agents.agent_toolkits.sql import base
from langchain.agents.agent import AgentExecutor

from langchain_community.llms.ollama import Ollama
from langchain_community.embeddings.ollama import OllamaEmbeddings

from langchain_community.llms.gigachat import GigaChat

### LangSmith

In [2]:
# Send LangSmith traces produced by this notebook to the "text2sql" project,
# then create the LangSmith client.
os.environ.update(LANGCHAIN_PROJECT="text2sql")
client = Client()

### Load models

In [3]:
# Sampling temperature shared by every model in this comparison, so all runs
# use the same setting and it can be changed in one place.
TEMPERATURE = 0.1

# Local models served through Ollama (base, instruct and quantization variants).
llama3 = Ollama(model="llama3:8b", temperature=TEMPERATURE)
llama3_embeddings = OllamaEmbeddings(model="llama3:8b", temperature=TEMPERATURE)

llama3_inst = Ollama(model="llama3:instruct", temperature=TEMPERATURE)
llama3_inst_q8 = Ollama(model="llama3:8b-instruct-q8_0", temperature=TEMPERATURE)
llama3_inst_fp16 = Ollama(model="llama3:8b-instruct-fp16", temperature=TEMPERATURE)

# Remote GigaChat model. Credentials are read from the GIGACHAT_AUTH
# environment variable (a KeyError is raised here if it is not set).
gigachat = GigaChat(
    model="GigaChat",
    temperature=TEMPERATURE,
    credentials=os.environ["GIGACHAT_AUTH"],
    # NOTE(review): TLS certificate verification is disabled. Prefer trusting
    # the issuing CA (e.g. via a CA bundle) over turning verification off.
    verify_ssl_certs=False,
    scope="GIGACHAT_API_PERS",
)

### Connect to the DB with a read-only role

In [4]:
# Open the database connection shared by every agent below. Per the section
# heading this should use a read-only role; the connection details live in the
# local db_connect module (not shown here) — TODO confirm the role there.
db = db_connect.get_db()

#### Check connection

In [5]:
# Smoke test: query the passenger table. The cell output below shows the rows
# returned by db.run as a single string.
db.run("select * from passenger")

"[(1, 'John'), (2, 'James'), (3, 'Poul'), (4, 'Christofer'), (5, 'Superman'), (6, 'Donald'), (7, 'Douglas'), (8, 'Dwight'), (9, 'Earl'), (10, 'Edgar'), (11, 'Edmund'), (12, 'Edwin'), (13, 'Elliot'), (14, 'Eric'), (15, 'Ernest'), (16, 'Ethan'), (17, 'Ezekiel'), (18, 'Felix'), (19, 'Franklin'), (20, 'Frederick'), (21, 'Gabriel'), (22, 'Joseph'), (23, 'Joshua'), (24, 'Julian'), (25, 'Alice'), (26, 'Bob'), (27, 'Charlie'), (28, 'David'), (29, 'Emily'), (30, 'Frank'), (31, 'George'), (32, 'Helen'), (33, 'Irene'), (34, 'Jack'), (35, 'Kate'), (36, 'Leo'), (37, 'Mary'), (38, 'Nancy'), (39, 'Oliver'), (40, 'Paul'), (41, 'Qiana'), (42, 'Robert'), (43, 'Samantha'), (44, 'Thomas'), (45, 'Victoria')]"

-----

## Create agent

### Prompt and error handler

In [6]:
# ReAct-style prompt for the SQL agent. {tools}, {tool_names}, {input} and
# {agent_scratchpad} are PromptTemplate variables filled in when the agent runs.
# The Thought / Action / Action Input / Observation / Final Answer labels are
# what LangChain's ReAct output parser looks for.
# NOTE: this is runtime prompt text — any edit changes what the LLM sees.
template =\
"""# Personality
You are an expert in PostgreSQL. Answer the question of an ordinary employee who does not know SQL. Please reply in plain text.
# Task
Please answer the following question to the best of your ability.
# Tools
Use the following tools to answer:
{tools}
# Format
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action... 
(this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

IF YOU WROTE \"FINAL ANSWER\". THEN DON\'T WRITE ANYTHING MORE
# Start chain of thoughts

Question: {input}
Thought:{agent_scratchpad}"""

# Build the template object; its input variables are inferred from the {...}
# placeholders in the text above.
prompt = PromptTemplate.from_template(template)

def err_handle(err):
    """Format an agent output-parsing error as feedback text for the LLM.

    Passed as ``handle_parsing_errors`` to the agent executor in
    ``get_agent``. LangChain's AgentExecutor uses the returned string as the
    observation for the step that failed to parse.

    Args:
        err: The parsing error; it is embedded in the message via str().

    Returns:
        A "# Error" message that quotes ``err`` and reminds the model not to
        emit an Action and a Final Answer in the same response.
    """
    message_lines = [
        "# Error",
        "The previous steps caused an error, here is its text: ",
        f'"{err}"',
        "DO NOT OUTPUT AN ACTION AND A FINAL ANSWER AT THE SAME TIME",
    ]
    return "\n".join(message_lines)

### Function to create an agent with the provided LLM

In [7]:
def get_agent(llm: BaseLanguageModel) -> AgentExecutor:
    """Build a zero-shot ReAct SQL agent driven by ``llm``.

    The agent queries the notebook-level ``db`` connection and uses the custom
    ``prompt``. Unparsable LLM output is routed to ``err_handle``, so the model
    gets corrective feedback instead of the run aborting.

    Args:
        llm: Language model that drives the agent's reasoning.

    Returns:
        An ``AgentExecutor`` ready for ``invoke`` / ``batch``.
    """
    executor_options = {"handle_parsing_errors": err_handle}
    sql_agent = base.create_sql_agent(
        llm=llm,
        db=db,
        prompt=prompt,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        agent_executor_kwargs=executor_options,
    )
    return sql_agent

-----
## Test the agent with different models and save the results

In [14]:
# Run the llama3-backed agent over every test question in a single batch call.
# `results` holds one dict per question with its 'input' and 'output' (see the
# printed cell output below).
results = get_agent(llama3).batch(TestData.QUESTIONS)

In [15]:
# Print each agent response (input question + output answer) on its own line.
if results:
    print(*results, sep="\n")

{'input': 'write a joke', 'output': 'Agent stopped due to iteration limit or time limit.'}
{'input': 'write a joke about cats', 'output': 'Why did the cat take a selfie? To capture its purr-fect side!'}


In [22]:
# Persist the raw agent results so they can be compared with other models later.
# Build the path with os.path.join so it works on both Windows and POSIX; a
# hard-coded backslash path only resolves to ../data on Windows. Create the
# target directory if needed so the first run does not fail.
results_path = os.path.join("..", "data", "llama3_res.pickle")
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, "wb") as f:
    pickle.dump(results, f)