In [None]:
%pip install -U huggingface_hub smolagents langchain python-dotenv sqlalchemy --upgrade -q

In [None]:
import os
from smolagents import CodeAgent, HfApiModel
from dotenv import load_dotenv
from smolagents.tools import Tool
from smolagents import tool, LiteLLMModel
from sqlalchemy import (
    create_engine,
    inspect,
    text,
)


  from .autonotebook import tqdm as notebook_tqdm


In [22]:
class SQLQueryTool(Tool):
    """Agent tool that runs read-only SQL statements against a SQLite file.

    The tool opens one ``sqlite3`` connection when it is constructed and
    reuses it for every call. Input that does not start with a read statement
    keyword is not executed; the caller gets a hint listing the available
    tables instead. Results are returned as a plain-text table.
    """

    name = "sql_query_tool"
    description = "Tool to query SQL database containing CSV data"
    inputs = {
        "query": {
            "type": "string",
            "description": "SQL query to execute or a natural language question about the database"
        }
    }
    output_type = "string"

    # Leading keywords of the statements this tool is willing to execute.
    # PRAGMA is deliberately absent so the read-only guard cannot be turned off.
    _READ_STATEMENT_KEYWORDS = ("SELECT", "WITH", "EXPLAIN", "VALUES")

    def __init__(self, db_path):
        """Remember the database location and open the connection.

        Args:
            db_path: Path of the SQLite database file to query.
        """
        self.db_path = db_path
        self.connection = None
        self.connect_to_db()
        super().__init__()

    def connect_to_db(self):
        """Connect to the SQLite database in query-only mode.

        On success ``self.connection`` holds the open connection; on failure
        the error is printed and ``self.connection`` is left unchanged.
        """
        import sqlite3

        connection = None
        try:
            connection = sqlite3.connect(self.db_path)
            # The statements come from an LLM, so refuse anything that would
            # modify the database (INSERT/UPDATE/DELETE/CREATE/DROP ...).
            connection.execute("PRAGMA query_only = ON")
        except sqlite3.Error as e:
            if connection is not None:
                connection.close()
            print(f"Error connecting to database: {e}")
            return

        self.connection = connection
        print(f"Connected to database at {self.db_path}")

    def execute_query(self, query):
        """Execute a SQL query and return the results.

        Args:
            query: SQL text of a single statement.

        Returns:
            ``{"columns": [...], "data": [...]}`` on success, where ``data``
            is the list of row tuples, or ``{"error": message}`` on failure.
        """
        if not self.connection:
            self.connect_to_db()
        if not self.connection:
            # Reconnecting failed; report that instead of an AttributeError text.
            return {"error": f"Could not connect to database at {self.db_path}"}

        # Broad catch on purpose: every failure is handed back to the agent as
        # text so it can correct the query instead of crashing the run.
        try:
            cursor = self.connection.cursor()
            try:
                cursor.execute(query)
                # cursor.description is None for statements that yield no result set.
                columns = [column[0] for column in cursor.description] if cursor.description else []
                return {"columns": columns, "data": cursor.fetchall()}
            finally:
                cursor.close()
        except Exception as e:
            return {"error": str(e)}

    def _looks_like_sql(self, query):
        """Return True when ``query`` starts with a supported read keyword.

        Leading whitespace and SQL comments (``-- ...`` and ``/* ... */``) are
        skipped first, then the first word is compared case-insensitively
        against ``_READ_STATEMENT_KEYWORDS``.
        """
        import re

        if not isinstance(query, str):
            return False
        body = re.sub(r"\A(?:\s|--[^\n]*|/\*.*?\*/)+", "", query, flags=re.DOTALL)
        first_word = re.match(r"[A-Za-z_][A-Za-z0-9_]*", body)
        return bool(first_word) and first_word.group(0).upper() in self._READ_STATEMENT_KEYWORDS

    def forward(self, query: str) -> str:
        """Process the user input and execute the appropriate SQL query.

        Args:
            query: SQL query to execute or a natural language question about the database

        Returns:
            Formatted results of the query or an error message
        """
        if not self._looks_like_sql(query):
            # Not a read statement: tell the caller which tables exist so it
            # can come back with real SQL.
            tables = self.get_tables()
            if tables:
                return f"Available tables: {', '.join(tables)}. Please specify a SQL query."
            return "Please provide a valid SQL query."

        result = self.execute_query(query)
        if "error" in result:
            return f"Error executing query: {result['error']}"

        if not result["columns"]:
            return "Results:\nQuery executed successfully but returned no data."

        # Header, a dash rule as wide as the header, then one line per row.
        header = " | ".join(result["columns"])
        lines = ["Results:", header, "-" * len(header)]
        lines.extend(" | ".join(str(cell) for cell in row) for row in result["data"])
        return "\n".join(lines) + "\n"

    def get_tables(self):
        """Get a list of all tables in the database.

        Returns:
            Table names read from ``sqlite_master``; an empty list when the
            database cannot be reached or the lookup fails.
        """
        if not self.connection:
            self.connect_to_db()
        if not self.connection:
            return []

        try:
            cursor = self.connection.cursor()
            try:
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
                return [table[0] for table in cursor.fetchall()]
            finally:
                cursor.close()
        except Exception as e:
            print(f"Error getting tables: {e}")
            return []

    def close(self):
        """Close the database connection."""
        if self.connection:
            self.connection.close()
            self.connection = None

In [9]:
# Load variables via python-dotenv first, then read the provider API keys
# from the process environment. A key that is not set comes back as None.
load_dotenv()
hf_api_key = os.environ.get("HUGGINGFACEHUB_API_KEY")
gemini_api_key = os.environ.get("GEMINI_API_KEY")
groq_api_key = os.environ.get("GROQ_API_KEY")

In [4]:
# SQLAlchemy engine for the SQLite hotel database (path relative to the notebook).
engine = create_engine(
    "sqlite:///./Database/hotel_database.db",
)

In [5]:
# Reflect the 'booking' table and print one "  - name: type" line per column.
inspector = inspect(engine)
columns_info = [
    (column["name"], column["type"])
    for column in inspector.get_columns("booking")
]

table_description = "Columns:\n" + "\n".join(
    f"  - {column_name}: {column_type}" for column_name, column_type in columns_info
)
print(table_description)

Columns:
  - index: INTEGER
  - hotel: TEXT
  - is_canceled: INTEGER
  - lead_time: INTEGER
  - arrival_date_year: INTEGER
  - arrival_date_month: TEXT
  - arrival_date_week_number: INTEGER
  - arrival_date_day_of_month: INTEGER
  - stays_in_weekend_nights: INTEGER
  - stays_in_week_nights: INTEGER
  - adults: INTEGER
  - children: REAL
  - babies: INTEGER
  - meal: TEXT
  - country: TEXT
  - market_segment: TEXT
  - distribution_channel: TEXT
  - is_repeated_guest: INTEGER
  - previous_cancellations: INTEGER
  - previous_bookings_not_canceled: INTEGER
  - reserved_room_type: TEXT
  - assigned_room_type: TEXT
  - booking_changes: INTEGER
  - deposit_type: TEXT
  - agent: REAL
  - company: REAL
  - days_in_waiting_list: INTEGER
  - customer_type: TEXT
  - adr: REAL
  - required_car_parking_spaces: INTEGER
  - total_of_special_requests: INTEGER
  - reservation_status: TEXT
  - reservation_status_date: TEXT


In [6]:
@tool
def sql_engine(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.
    The table is named 'booking'. Its description is as follows:
        Columns:
        - index: INTEGER
        - hotel: TEXT
        - is_canceled: INTEGER
        - lead_time: INTEGER
        - arrival_date_year: INTEGER
        - arrival_date_month: TEXT
        - arrival_date_week_number: INTEGER
        - arrival_date_day_of_month: INTEGER
        - stays_in_weekend_nights: INTEGER
        - stays_in_week_nights: INTEGER
        - adults: INTEGER
        - children: REAL
        - babies: INTEGER
        - meal: TEXT
        - country: TEXT
        - market_segment: TEXT
        - distribution_channel: TEXT
        - is_repeated_guest: INTEGER
        - previous_cancellations: INTEGER
        - previous_bookings_not_canceled: INTEGER
        - reserved_room_type: TEXT
        - assigned_room_type: TEXT
        - booking_changes: INTEGER
        - deposit_type: TEXT
        - agent: REAL
        - company: REAL
        - days_in_waiting_list: INTEGER
        - customer_type: TEXT
        - adr: REAL
        - required_car_parking_spaces: INTEGER
        - total_of_special_requests: INTEGER
        - reservation_status: TEXT
        - reservation_status_date: TEXT

    Args:
        query: The query to perform. This should be correct SQL.
    """
    # NOTE(review): the docstring above is left verbatim because the @tool
    # decorator is expected to build the tool description from it - confirm
    # against the smolagents docs before rewording it.
    with engine.connect() as connection:
        result_rows = connection.execute(text(query))
        # One line per row, each preceded by a newline; empty string when the
        # statement yields no rows.
        return "".join("\n" + str(record) for record in result_rows)

In [13]:
# Chat model served through LiteLLM. The "groq/" prefix selects the Groq
# provider, so the Groq key loaded from the environment is used.
# Alternative model id tried earlier: "gemini-2.0-flash-exp".
model = LiteLLMModel(
    api_key=groq_api_key,
    model_id="groq/qwen-2.5-32b",
)

In [14]:
# Code-writing agent whose only tool is the SQLAlchemy-backed `sql_engine`.
# Alternative model tried earlier:
#   HfApiModel(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct", token=hf_api_key)
agent = CodeAgent(model=model, tools=[sql_engine])

# Kept as the cell's last expression so the answer is shown as the cell output.
agent.run("Show me total revenue for July 2017.")

3132959.07

In [23]:
# Second variant: the class-based SQLite tool with the default HfApiModel.
# NOTE: this rebinds `model`, replacing the LiteLLM model defined earlier.
sql_tool = SQLQueryTool(db_path="./Database/hotel_database.db")
model = HfApiModel()
agents = CodeAgent(model=model, tools=[sql_tool])

Connected to database at ./Database/hotel_database.db


In [24]:
# Same question as before, this time answered through SQLQueryTool.
agents.run(
    "Show me total revenue for July 2017"
)

3132959.07