**Reference: https://python.langchain.com/v0.1/docs/use_cases/sql/large_db/**


#### What happens in this notebook:

##### **Table Model Definition**
- **`Table` Class**: This is a simple Pydantic model representing a SQL table. It has one attribute, `name`, which is a string and is described as "Name of table in SQL database."
  - This model is used in the extraction process to match relevant SQL tables based on the user's query.

---

##### **Helper Function - `get_tables`**
- **`get_tables`**: This function takes a list of `Table` objects whose `name` holds a category (such as "Music" or "Business") and returns the SQL table names that belong to those categories.
  - For example, if the category is `"Music"`, the tables `"Album"`, `"Artist"`, `"Genre"`, etc., are added to the result.
  - Similarly, for `"Business"`, the corresponding tables like `"Customer"`, `"Employee"`, etc., are included.

---

##### **LLM and Environment Setup**
- Loads environment variables (`MODEL_NAME`, `TEMPERATURE`, `GEMINI_API_KEY`) using `dotenv`.
- Initializes two LLMs using `ChatGoogleGenerativeAI`:
  - `sql_agent_llm`: for SQL query generation.
  - `table_extractor_llm`: for extracting relevant tables or categories.

---

##### **Connect to the SQL Database (`db`)**
- Connects to the Chinook SQLite database using `SQLDatabase.from_uri`.
- Prints available tables.
- Runs a sample query: `SELECT * FROM Artist LIMIT 10;`

---

##### **Strategies for Table Selection**

##### **Strategy A: Direct Table Extraction**
- Joins all table names from the database into the system prompt.
- Uses the LLM with `create_extraction_chain_pydantic` to return the table names relevant to the user's question.

##### **Strategy B: Rule-based Category Mapping**
- Uses a predefined mapping:
  - `"Music"` → Album, Artist, Genre, MediaType, Playlist, PlaylistTrack, Track
  - `"Business"` → Customer, Employee, Invoice, InvoiceLine
- The LLM decides the category and returns every table listed for it.

##### **Strategy C: Category → Table Mapping (2-Step)**
- **Step 1**: Uses LLM to classify the question into a category (e.g., Music or Business).
- **Step 2**: Python function `get_tables()` maps the category to actual SQL table names.

---

##### **SQL Query Construction and Cleaning**
- **`query_chain`**: Generates the SQL query using LLM and `create_sql_query_chain()`.
- **`extract_sql_query()`**: Extracts and cleans raw LLM output to return only the valid SQL string (removes code blocks, `SQLQuery:` prefix, semicolons, etc.).
- Uses `RunnablePassthrough` and `RunnableLambda` to build a pipeline:
  1. Takes user question.
  2. Identifies relevant tables via `table_chain`.
  3. Generates and cleans SQL query.
  4. Returns the cleaned SQL string; the query is executed separately with `db.run`.

---

##### **Full Agent Pipeline (`full_chain`)**
- **Input Mapping**: Maps the caller's `question` key to the `input` key that `table_chain` expects, via `{"input": itemgetter("question")}`.
- **Pipeline Flow**:
  - `table_chain` → assigns `table_names_to_use`
  - `query_chain` → generates SQL using table context
  - `clean_query_output` → extracts final query string

In [None]:
import os
from dotenv import load_dotenv
from pyprojroot import here
from typing import List
from langchain_community.utilities import SQLDatabase
from langchain_google_genai import ChatGoogleGenerativeAI
from pprint import pprint

load_dotenv()

True

**Set the environment variables and load the LLM**

In [None]:
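# os.getenv returns strings, so convert the temperature to a float explicitly.
# The "0" fallback is an assumption for the case where TEMPERATURE is unset.
temperature = float(os.getenv("TEMPERATURE", "0"))

sql_agent_llm = ChatGoogleGenerativeAI(
    model=os.getenv("MODEL_NAME"),
    temperature=temperature,
    google_api_key=os.getenv("GEMINI_API_KEY"),
)

table_extractor_llm = ChatGoogleGenerativeAI(
    model=os.getenv("MODEL_NAME"),
    temperature=temperature,
    google_api_key=os.getenv("GEMINI_API_KEY"),
)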
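
Values read from the environment are strings, or missing when a variable is unset, so it can help to validate them before constructing the LLMs. The following is a minimal sketch of that idea. The helper name `load_llm_settings` and the `0.0` temperature default are assumptions for illustration and are not part of the notebook.

```python
import os
from typing import Dict, Mapping, Union


def load_llm_settings(env: Mapping[str, str] = os.environ) -> Dict[str, Union[str, float]]:
    """Read the LLM settings from an environment mapping (hypothetical helper)."""
    # Fail early with a clear message if a required variable is absent or empty.
    missing = [key for key in ("MODEL_NAME", "GEMINI_API_KEY") if not env.get(key)]
    if missing:
        raise KeyError(f"Missing environment variables: {', '.join(missing)}")
    return {
        "model": env["MODEL_NAME"],
        # Environment values are text; convert explicitly. The 0.0 default is an assumption.
        "temperature": float(env.get("TEMPERATURE", "0.0")),
        "google_api_key": env["GEMINI_API_KEY"],
    }


# Example with an explicit mapping instead of the real environment.
settings = load_llm_settings(
    {"MODEL_NAME": "example-model", "TEMPERATURE": "0.2", "GEMINI_API_KEY": "example-key"}
)
print(settings["temperature"])
```

The returned dictionary uses the same keyword names as the `ChatGoogleGenerativeAI` calls in this notebook, so it could be unpacked with `**settings`.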

In [3]:
sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

**Prepare the `Table` class**

In [4]:
from langchain_core.pydantic_v1 import BaseModel, Field

class Table(BaseModel):
    """
    Represents a table in the SQL database.

    Attributes:
        name (str): The name of the table in the SQL database.
    """
    name: str = Field(description="Name of table in SQL database.")


Running this cell emits a LangChain deprecation warning for the `langchain_core.pydantic_v1` import. The warning suggests replacing `from langchain_core.pydantic_v1 import BaseModel` with `from pydantic import BaseModel`, or with `from pydantic.v1 import BaseModel` in a code base that has not been fully upgraded to pydantic 2 yet.


### **Strategy A:**

In [5]:
table_names = "\n".join(db.get_usable_table_names())
pprint(table_names)

('Album\n'
 'Artist\n'
 'Customer\n'
 'Employee\n'
 'Genre\n'
 'Invoice\n'
 'InvoiceLine\n'
 'MediaType\n'
 'Playlist\n'
 'PlaylistTrack\n'
 'Track')


In [6]:
from langchain.chains.openai_tools import create_extraction_chain_pydantic

system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{table_names}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""
table_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)
table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})


[Table(name='Artist'), Table(name='Track'), Table(name='Genre')]
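
Because the table names in Strategy A come from free-form LLM output, the model may return a name that does not exist in the database, or the same name twice. A minimal sketch of a post-filter follows. The `Table` dataclass is a stand-in for the notebook's Pydantic model so the block runs without LangChain, and `keep_known_tables` and the name `"Songs"` are hypothetical.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class Table:
    """Stand-in for the notebook's Pydantic `Table` model."""
    name: str


def keep_known_tables(extracted: Iterable[Table], known: Iterable[str]) -> List[str]:
    """Return extracted names that exist in `known`, in order and without duplicates."""
    known_set = set(known)
    seen = set()
    result: List[str] = []
    for table in extracted:
        if table.name in known_set and table.name not in seen:
            seen.add(table.name)
            result.append(table.name)
    return result


known_tables = ["Album", "Artist", "Genre", "Track"]
# "Songs" is a made-up name standing in for a hallucinated table.
extracted = [Table("Artist"), Table("Track"), Table("Genre"), Table("Songs"), Table("Track")]
print(keep_known_tables(extracted, known_tables))  # ['Artist', 'Track', 'Genre']
```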

### **Strategy B:**

Music:

- "Album"
- "Artist"
- "Genre"
- "MediaType"
- "Playlist"
- "PlaylistTrack"
- "Track"

Business:

- "Customer"
- "Employee"
- "Invoice"
- "InvoiceLine"

In [7]:
from langchain.chains.openai_tools import create_extraction_chain_pydantic

system = f"""You will receive a question.

If the question is about **Music**, return **ALL** these tables:
  - "Album"
  - "Artist"
  - "Genre"
  - "MediaType"
  - "Playlist"
  - "PlaylistTrack"
  - "Track"

If the question is about **Business**, return **ALL** these tables:
  - "Customer"
  - "Employee"
  - "Invoice"
  - "InvoiceLine"

If you are unsure, return the full list of all available tables for both Music and Business categories."""
table_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)
table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

[Table(name='Album'),
 Table(name='Artist'),
 Table(name='Genre'),
 Table(name='MediaType'),
 Table(name='Playlist'),
 Table(name='PlaylistTrack'),
 Table(name='Track')]

### **Strategy C:**

- **Step 1: Define the category**

In [8]:
from langchain.chains.openai_tools import create_extraction_chain_pydantic

system = """Return the names of the SQL tables that are relevant to the user question. \
The tables are:

Music
Business"""
category_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)

In [9]:
category_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

[Table(name='Music')]

- **Step 2: Execute the python function**

In [10]:
def get_tables(categories: List[Table]) -> List[str]:
    """Maps category names to corresponding SQL table names.

    Args:
        categories (List[Table]): A list of `Table` objects representing different categories.

    Returns:
        List[str]: A list of SQL table names corresponding to the provided categories.
    """
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend(
                [
                    "Album",
                    "Artist",
                    "Genre",
                    "MediaType",
                    "Playlist",
                    "PlaylistTrack",
                    "Track",
                ]
            )
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables


table_chain = category_chain | get_tables 
table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

['Album', 'Artist', 'Genre', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
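
The `if`/`elif` chain in `get_tables` can also be expressed as a dictionary lookup, which keeps the category mapping in one place and makes adding a category a data change. The sketch below is one possible variant, not the notebook's implementation. `CATEGORY_TABLES` and `get_tables_from_mapping` are hypothetical names, and the `Table` dataclass stands in for the Pydantic model.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Table:
    """Stand-in for the notebook's Pydantic `Table` model."""
    name: str


# Same category-to-table mapping as in the notebook, expressed as data.
CATEGORY_TABLES: Dict[str, List[str]] = {
    "Music": ["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"],
    "Business": ["Customer", "Employee", "Invoice", "InvoiceLine"],
}


def get_tables_from_mapping(categories: List[Table]) -> List[str]:
    """Map category names to table names, skipping unknown categories and duplicates."""
    tables: List[str] = []
    for category in categories:
        for table in CATEGORY_TABLES.get(category.name, []):
            if table not in tables:
                tables.append(table)
    return tables


print(get_tables_from_mapping([Table("Music")]))
```

Unlike the original function, this variant does not repeat tables when the LLM returns the same category twice.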

### **Final step:**

**Attach the desired strategy to your SQL agent**

In [11]:
import re

def extract_sql_query(text: str) -> str:
    """
    Extract a clean SQL query from a possibly formatted or prefixed LLM output.
    Handles code blocks, 'SQLQuery:' prefix, and various SQL command types.
    """
    text = text.strip()

    # Step 1: Remove code block markers
    if text.startswith("```"):
        lines = text.splitlines()
        # Remove first/last lines if they're ``` or ```sql
        if lines[0].startswith("```") and lines[-1].startswith("```"):
            text = "\n".join(lines[1:-1]).strip()

    # Step 2: Remove known prefixes like "SQLQuery:"
    text = re.sub(r"^(SQLQuery:|Query:)\s*", "", text, flags=re.IGNORECASE)

    # Step 3: Extract only the actual SQL statement (any type)
    match = re.search(
        r"(SELECT|INSERT|UPDATE|DELETE|PRAGMA|CREATE|DROP|ALTER|DESCRIBE|SHOW)\s.+",
        text,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if match:
        query = match.group(0).strip()
        # Remove trailing backticks or semicolons
        return query.rstrip(";`").strip()

    # Step 4: If nothing matched, return original (but cleaned)
    return text
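
To make the behaviour of `extract_sql_query` concrete, the sketch below runs it on a few typical LLM output shapes. The function is repeated in condensed form so the block runs on its own, and the sample strings are invented for illustration.

```python
import re

FENCE = "`" * 3  # three backticks, built this way so the example is easy to embed


def extract_sql_query(text: str) -> str:
    """Condensed copy of the notebook's cleaner, so this block runs standalone."""
    text = text.strip()
    # Drop surrounding code-fence lines if present.
    if text.startswith(FENCE):
        lines = text.splitlines()
        if lines[0].startswith(FENCE) and lines[-1].startswith(FENCE):
            text = "\n".join(lines[1:-1]).strip()
    # Drop a leading "SQLQuery:" or "Query:" prefix.
    text = re.sub(r"^(SQLQuery:|Query:)\s*", "", text, flags=re.IGNORECASE)
    # Keep everything from the first SQL keyword onwards.
    match = re.search(
        r"(SELECT|INSERT|UPDATE|DELETE|PRAGMA|CREATE|DROP|ALTER|DESCRIBE|SHOW)\s.+",
        text,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if match:
        return match.group(0).strip().rstrip(";`").strip()
    return text


samples = [
    f"{FENCE}sql\nSELECT Name FROM Genre;\n{FENCE}",
    "SQLQuery: SELECT COUNT(*) FROM Track;",
    "Here is the query:\nSELECT Name FROM Artist LIMIT 5",
]
for sample in samples:
    print(repr(extract_sql_query(sample)))
```

One limitation worth knowing: the keyword search takes the leftmost match, so leading prose that itself contains a keyword followed by whitespace (for instance `Please show the result: SELECT 1`) is returned from that keyword onwards rather than from the actual query.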


In [12]:
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain.chains import create_sql_query_chain
from operator import itemgetter

query_chain = create_sql_query_chain(sql_agent_llm, db)
clean_query_output = RunnableLambda(extract_sql_query)
# Convert "question" key to the "input" key expected by current table_chain.
table_chain = {"input": itemgetter("question")} | table_chain
# Set table_names_to_use using table_chain.
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain | clean_query_output
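
The data flow of `full_chain` can be followed without calling an LLM. The sketch below is a plain-Python analogue, not LangChain code: `assign` imitates how `RunnablePassthrough.assign` keeps the input dictionary and adds a computed key, and the three `fake_*` functions are hypothetical stubs with hard-coded outputs standing in for `table_chain`, `query_chain`, and `clean_query_output`.

```python
import re
from typing import Any, Callable, Dict, List


def assign(**steps: Callable[[Dict[str, Any]], Any]) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
    """Keep the input keys and add one computed key per step."""
    def run(inputs: Dict[str, Any]) -> Dict[str, Any]:
        return {**inputs, **{key: step(inputs) for key, step in steps.items()}}
    return run


def fake_table_chain(inputs: Dict[str, Any]) -> List[str]:
    # Stub: a real table chain would ask the LLM which tables are relevant.
    return ["Artist", "Album", "Track", "Genre"]


def fake_query_chain(inputs: Dict[str, Any]) -> str:
    # Stub: a real query chain would prompt the LLM with the selected tables' schema.
    return f"SQLQuery: SELECT Name FROM {inputs['table_names_to_use'][-1]};"


def fake_clean(text: str) -> str:
    # Stub: simplified version of the cleaning step.
    return re.sub(r"^SQLQuery:\s*", "", text).rstrip(";").strip()


def full_chain_sketch(inputs: Dict[str, Any]) -> str:
    with_tables = assign(table_names_to_use=fake_table_chain)(inputs)
    return fake_clean(fake_query_chain(with_tables))


state = assign(table_names_to_use=fake_table_chain)({"question": "Which genres exist?"})
print(sorted(state))  # ['question', 'table_names_to_use']
print(full_chain_sketch({"question": "Which genres exist?"}))  # SELECT Name FROM Genre
```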

**Test the agent**

In [13]:
query = full_chain.invoke(
    {"question": "What are all the genres of Alanis Morisette songs"}
)
print(query)

SELECT DISTINCT
  T2.Name
FROM Artist AS T1
INNER JOIN Album AS T3
  ON T1.ArtistId = T3.ArtistId
INNER JOIN Track AS T4
  ON T3.AlbumId = T4.AlbumId
INNER JOIN Genre AS T2
  ON T4.GenreId = T2.GenreId
WHERE
  T1.Name = 'Alanis Morissette'


In [14]:
db.run(query)

"[('Rock',)]"

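As the output above shows, `db.run` returns the rows as a single string rather than a list. When the result is needed as Python objects, one option is `ast.literal_eval`, sketched below. The helper name `parse_rows` is hypothetical, and the empty-string guard is a defensive assumption about queries that return no rows.

```python
import ast
from typing import Any, List, Tuple


def parse_rows(raw: str) -> List[Tuple[Any, ...]]:
    """Parse the string form of a row list back into Python tuples (hypothetical helper)."""
    if not raw.strip():
        # Assumption: treat an empty string as "no rows".
        return []
    return ast.literal_eval(raw)


rows = parse_rows("[('Rock',)]")
genres = [row[0] for row in rows]
print(genres)  # ['Rock']
```

This only works when every value in the result has a literal Python representation (strings, numbers, `None`). Other column types would need a different approach.
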
**Prepare the tool (Don't run the following cell)**

In [None]:
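from langchain_core.tools import tool

# Note: TOOLS_CFG is not defined in this notebook. It is assumed to be the
# project's tool configuration object, which is why this cell is not meant to be run here.


class ChinookSQLAgent: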
    """
    SQL agent for interacting with the Chinook database using LLM-generated queries.

    This agent uses a language model to:
    - Identify relevant SQL tables based on the question.
    - Construct a SQL query targeting only those tables.
    - Clean and return the final SQL query string.

    Attributes:
        sql_agent_llm (ChatGoogleGenerativeAI): Configured LLM for query understanding and generation.
        db (SQLDatabase): SQL database connection for the Chinook DB.
        full_chain (Runnable): Execution pipeline for table extraction, query generation, and cleanup.
    """
    
    def __init__(self, sqldb_directory: str, llm: str, llm_temerature: float, llm_api_key: str) -> None:
        """
        Initialize the ChinookSQLAgent with database path and LLM configuration.

        Args:
            sqldb_directory (str): Path to the Chinook SQLite database.
            llm (str): Name of the LLM model (e.g., "gemini-2.5-flash").
            llm_temerature (float): Temperature for LLM response variability.
            llm_api_key (str): API key for the LLM service provider.
        """
        self.sql_agent_llm = ChatGoogleGenerativeAI(
            model=llm,
            temperature=llm_temerature,
            google_api_key=llm_api_key
        )
        self.db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
        print(self.db.get_usable_table_names())
        category_chain_system = """Return the names of the SQL tables that are relevant to the user question. \
        The tables are:

        Music
        Business"""
        category_chain = create_extraction_chain_pydantic(
            Table, self.sql_agent_llm, system_message=category_chain_system)
        table_chain = category_chain | get_tables  # noqa
        query_chain = create_sql_query_chain(self.sql_agent_llm, self.db)
        clean_query_output = RunnableLambda(extract_sql_query)
        # Convert "question" key to the "input" key expected by current table_chain.
        table_chain = {"input": itemgetter("question")} | table_chain
        # Set table_names_to_use using table_chain.
        self.full_chain = RunnablePassthrough.assign(
            table_names_to_use=table_chain) | query_chain | clean_query_output


@tool
def query_chinook_sqldb(query: str) -> str:
    """
    Query the Chinook SQL database using natural language.

    Args:
        query (str): User's natural language question.

    Returns:
        str: Query result after LLM-driven SQL generation and execution.
    """
    # Create an instance of ChinookSQLAgent
    agent = ChinookSQLAgent(
        sqldb_directory=TOOLS_CFG.chinook_sqldb_directory,
        llm=TOOLS_CFG.chinook_sqlagent_llm,
        llm_temerature=TOOLS_CFG.chinook_sqlagent_llm_temperature,
        llm_api_key=TOOLS_CFG.chinook_sqlagent_llm_api_key
    )

    query = agent.full_chain.invoke({"question": query})

    return agent.db.run(query)
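
The tool above constructs a new `ChinookSQLAgent`, and therefore a new database connection and chain, on every call. If that cost matters, one option is to cache the instance. The sketch below illustrates the idea with `functools.lru_cache`. `FakeAgent` and `get_agent` are hypothetical stand-ins so the block runs without LangChain or `TOOLS_CFG`.

```python
from functools import lru_cache


class FakeAgent:
    """Stand-in for ChinookSQLAgent that counts how often it is constructed."""
    instances = 0

    def __init__(self, sqldb_directory: str) -> None:
        FakeAgent.instances += 1
        self.sqldb_directory = sqldb_directory


@lru_cache(maxsize=1)
def get_agent(sqldb_directory: str) -> FakeAgent:
    # Built on the first call; later calls with the same path reuse the instance.
    return FakeAgent(sqldb_directory)


first = get_agent("data/Chinook.db")
second = get_agent("data/Chinook.db")
print(first is second, FakeAgent.instances)  # True 1
```

Whether caching is appropriate depends on how the configuration can change at runtime; a cached agent will not pick up a new model name or API key until the cache is cleared with `get_agent.cache_clear()`.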
