# SQL Chain, based on `SQLDatabaseChain`

## Setup

### Imports

In [None]:
import os
import pickle

import db_connect 
from test_data import TestData

from langsmith import Client

from langchain_community.llms.ollama import Ollama

from langchain_community.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain

from langchain_community.llms.gigachat import GigaChat

### LangSmith

In [None]:
# Select the LangSmith project ("text2sql") that LangChain runs from this
# notebook are reported to, then create a LangSmith client.
os.environ.update(LANGCHAIN_PROJECT="text2sql")
client = Client()

### Models

In [None]:
# GigaChat model; credentials are read from the GIGACHAT_AUTH environment
# variable (a missing variable raises KeyError here).
# NOTE(review): verify_ssl_certs=False disables TLS certificate verification
# for the GigaChat API. Prefer installing the provider's CA certificate and
# enabling verification.
gigachat = GigaChat(
    credentials=os.environ["GIGACHAT_AUTH"],
    verify_ssl_certs=False,
    scope="GIGACHAT_API_PERS",
)

# Ollama-served Llama 3 variants (base, instruct, and the q8_0 / fp16 instruct
# builds), all created with temperature 0.
llama3, llama3_inst, llama3_inst_q8, llama3_inst_fp16 = (
    Ollama(model=tag, temperature=0)
    for tag in (
        "llama3:8b",
        "llama3:instruct",
        "llama3:8b-instruct-q8_0",
        "llama3:8b-instruct-fp16",
    )
)

### Connect to DB

In [None]:
# Obtain the database handle from the project helper module `db_connect`.
# Presumably a langchain `SQLDatabase` (it is queried with `db.run` and passed
# to `SQLDatabaseChain.from_llm`) — confirm against db_connect.get_db.
db = db_connect.get_db()

#### Check connection

In [None]:
# Smoke test: query the `passenger` table; the cell output shows the result.
# NOTE(review): this fetches every row of the table — if it is large, consider
# restricting the number of rows in a way the target SQL dialect supports.
db.run("SELECT * FROM passenger")

----
## Testing the chain with different models and saving the results

In [None]:
# Build a text-to-SQL chain over `db` driven by the base llama3 model;
# use_query_checker=True enables the chain's query-checking step for the
# generated SQL before it is executed.
db_chain = SQLDatabaseChain.from_llm(
    llm=llama3,
    db=db,
    use_query_checker=True,
)

In [None]:
# Run every test question through the chain in one batch. With
# return_exceptions=True a failing question yields its exception object in
# `results` instead of aborting the whole batch.
results = db_chain.batch(
    TestData.QUESTIONS,
    return_exceptions=True,
)

In [None]:
# Print every batch result in order. Successful runs are dicts whose "result"
# key holds the chain's answer; failed runs (exception objects kept by
# return_exceptions=True) are printed in ANSI red instead.
# The separator uses single quotes inside the f-string: reusing the enclosing
# double quote is only legal from Python 3.12 on (PEP 701).
for i, result in enumerate(results):
    print(f"Result number {i + 1}:")
    print(
        result["result"] if isinstance(result, dict) else f"\033[91m{result}\033[0m",
        end=f"\n{'-' * 250}\n",
    )

In [None]:
# Save the raw batch results (answers and captured exceptions) with pickle so
# they can be reloaded later without re-running the chain.
# os.path.join keeps the path portable (the original used Windows-only "\\"
# separators); the target directory is created if it does not exist yet so the
# save cannot fail after the batch run has already finished.
results_path = os.path.join("..", "data", "db_chain", "llama3_res.pickle")
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, "wb") as f:
    pickle.dump(results, f)