In [1]:
import os 
from dotenv import load_dotenv

# Load variables from a local .env file (if one exists) into os.environ.
load_dotenv()

GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
# Fail fast with an actionable message. Without this check a missing key would
# surface as an opaque "TypeError: str expected, not NoneType" on the assignment
# below, and an empty key would only fail later inside the Gemini client.
if not GOOGLE_API_KEY:
    raise RuntimeError(
        "GOOGLE_API_KEY is not set; export it or add it to a .env file "
        "before running this notebook."
    )
# Make sure the key is present in the process environment for the Gemini client.
os.environ['GOOGLE_API_KEY'] = GOOGLE_API_KEY

In [2]:
from langchain_core.prompts import ChatPromptTemplate

# Prompt asking the model to translate a question into SQL for the given schema.
# Both {schema} and {question} are filled in when the chain is invoked.
template = (
    "Based on the table schema below, write a SQL query that would answer the user's question:\n"
    "{schema}\n"
    "\n"
    "Question: {question}\n"
    "SQL Query:"
)

prompt = ChatPromptTemplate.from_template(template)

In [3]:
# Render the SQL-generation prompt with placeholder values to inspect its final shape.
prompt.format(question="how many users are there?", schema="my schema")

"Human: Based on the table schema below, write a SQL query that would answer the user's question:\nmy schema\n\nQuestion: how many users are there?\nSQL Query:"

In [4]:
from langchain_community.utilities import SQLDatabase

# Chinook sample database, resolved relative to the notebook's working directory.
CHINOOK_DB_PATH = './chinook.db'

# SQLite creates a brand-new empty database when the file does not exist, so a
# wrong working directory would otherwise only show up later as "no such table".
if not os.path.isfile(CHINOOK_DB_PATH):
    raise FileNotFoundError(
        f"{CHINOOK_DB_PATH!r} not found (cwd: {os.getcwd()!r}); "
        "place the Chinook sample database there before running this notebook."
    )

sqlite_uri = f'sqlite:///{CHINOOK_DB_PATH}'

# NOTE(review): LLM-generated SQL is later executed through this connection;
# consider opening it read-only so generated statements cannot modify data.
db = SQLDatabase.from_uri(sqlite_uri)


In [5]:
# Sanity check that the database is reachable: fetch the first five Album rows.
db.run("SELECT * FROM Album LIMIT 5", fetch="all")

"[(1, 'For Those About To Rock We Salute You', 1), (2, 'Balls to the Wall', 2), (3, 'Restless and Wild', 2), (4, 'Let There Be Rock', 1), (5, 'Big Ones', 3)]"

In [6]:
def get_schema(_):
    """Return the table information string produced by ``db.get_table_info()``.

    The single argument is ignored. It exists because the chains below pass this
    function to ``RunnablePassthrough.assign``, which calls it with the chain input.
    """
    schema_text = db.get_table_info()
    return schema_text

In [18]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_google_genai import ChatGoogleGenerativeAI
import re

# Gemini chat model used by both the SQL-generation and the answer chains below.
# NOTE(review): "gemini-pro" is a first-generation model id; confirm it is still served.
GEMINI_MODEL = "gemini-pro"
llm = ChatGoogleGenerativeAI(model=GEMINI_MODEL)

# Fence explicitly labelled as SQL (```sql / ```sqlite, any case). "sqlite" is tried
# first and `\b` requires the tag to end there, so "```sqlite" is no longer read as
# "```sql" followed by a stray "ite" at the start of the query.
SQL_FENCE_PATTERN = re.compile(r"```(?:sqlite|sql)\b\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)
# Any other fence (unlabelled or with another language tag). The tag must be followed
# by a newline so SQL written on the same line as ``` is not swallowed as the tag.
GENERIC_FENCE_PATTERN = re.compile(r"```[\w-]*[ \t]*\n\s*(.*?)\s*```", re.DOTALL)


def extract_sql(text):
    """Extract the SQL code from an LLM reply that may wrap it in code fences.

    Args:
        text: Raw model output.

    Returns:
        The body of the first ```sql / ```sqlite fence (case-insensitive) if present;
        otherwise the body of the first other fence whose opening line is just ```
        plus an optional tag; otherwise ``text`` with surrounding whitespace stripped.
    """
    for pattern in (SQL_FENCE_PATTERN, GENERIC_FENCE_PATTERN):
        match = pattern.search(text)
        if match:
            return match.group(1)
    # No code fences found: assume the reply is already bare SQL.
    return text.strip()

# {"question": ...} -> add DB schema -> fill prompt -> Gemini -> plain text -> bare SQL.
# The stop sequence presumably keeps the model from continuing past the query into an
# invented "SQLResult:" section.
sql_chain = RunnablePassthrough.assign(schema=get_schema).pipe(
    prompt,
    llm.bind(stop=["\nSQLResult:"]),
    StrOutputParser(),
    RunnableLambda(extract_sql),
)

In [19]:
# Generate (but do not execute) the SQL for a sample question.
sql_chain.invoke(dict(question="how many artists are there?"))

'SELECT COUNT(*) AS "Number of Artists"\nFROM Artist;'

In [20]:
# Second prompt: turn the question, the generated SQL and its raw result into a
# natural-language answer. This rebinds `template`; the SQL-generation prompt
# built earlier is unaffected because `prompt` was already constructed from it.
template = (
    "Based on the table schema below, question, sql query, and sql response, write a natural language response:\n"
    "{schema}\n"
    "\n"
    "Question: {question}\n"
    "SQL Query: {query}\n"
    "SQL Response: {response}"
)
prompt_response = ChatPromptTemplate.from_template(template)


In [21]:
# Accepted start of a text query: optional `--` line comments, then SELECT or WITH
# (case-insensitive, as a whole word).
READ_ONLY_SQL_PATTERN = re.compile(
    r"\s*(?:--[^\n]*(?:\n|$)\s*)*(?:select|with)\b",
    re.IGNORECASE,
)


def run_query(query):
    """Execute ``query`` against the notebook's ``db`` connection and return the result.

    ``full_chain`` feeds this function SQL written by the LLM, so string queries are
    only executed when they start with SELECT or WITH. This is a cheap guard, not a
    sandbox (SQLite accepts e.g. ``WITH ... DELETE``), so prefer a read-only database
    connection for real data as well.

    Args:
        query: SQL text. Non-string values are passed to ``db.run`` unchecked.

    Returns:
        Whatever ``db.run`` returns for the query.

    Raises:
        ValueError: if a string query does not start with SELECT or WITH.
    """
    if isinstance(query, str) and not READ_ONLY_SQL_PATTERN.match(query):
        raise ValueError(f"Refusing to run non-read-only SQL: {query!r}")
    return db.run(query)


In [22]:
# Execute the SQL generated above by hand to confirm it runs and see the raw result.
run_query(query='SELECT COUNT(*) AS "Number of Artists"\nFROM Artist;')

'[(275,)]'

In [25]:
# End-to-end: question -> generated SQL -> executed result -> natural-language answer.
# The second `.assign` adds the schema and the query result alongside `question`
# and `query` so that `prompt_response` can fill every placeholder.
full_chain = (
    RunnablePassthrough
    .assign(query=sql_chain)
    .assign(schema=get_schema, response=lambda inputs: run_query(inputs["query"]))
    | prompt_response
    | llm
    | StrOutputParser()
)

In [26]:
# Ask the complete pipeline the same sample question and get a prose answer.
full_chain.invoke(dict(question="how many artists are there?"))

'There are 275 artists in the database.'