In [1]:
# Install the packages listed in requirements.txt; the %pip magic targets the
# environment of the running kernel.
%pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [2]:
!pip install -U google-generativeai



In [3]:
# --- Load Environment Variables ---
# Ensure your .env file exists in the same directory as this notebook.
# load_dotenv() loads its entries into the process environment so the os.getenv()
# calls in the next cell can read the API key and database settings.

import os
from dotenv import load_dotenv
import decimal # Decimal is one of the scalar types ask_sql_query() checks for in query results

load_dotenv()

True

In [4]:
# --- Read credentials from the environment ---
# os.getenv() yields None for any variable that is not set.
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')

# MySQL connection settings, read in one pass (the environment keys are lower-case).
db_user, db_password, db_host, db_port, db_name = (
    os.getenv(env_key)
    for env_key in ("db_user", "db_password", "db_host", "db_port", "db_name")
)

In [5]:
import google.generativeai as genai

# --- Configure Gemini API ---
# Registers the API key (read from the environment in the previous cell) with the SDK.
# Every later genai call in this notebook (list_models, GenerativeModel) relies on it
# to authenticate against Google's generative models.
genai.configure(api_key=GOOGLE_API_KEY)

In [6]:
# --- Initialize SQLDatabase ---
# Connects to the MySQL database and extracts its schema information.
# 'sample_rows_in_table_info=3' adds three sample rows per table to the schema text,
# which helps Gemini understand the data types and typical values.

from urllib.parse import quote_plus

from langchain.utilities import SQLDatabase

print("--- Connecting to MySQL Database and Fetching Schema ---")

try:
    # URL-encode the credentials: characters such as '@', ':' or '/' inside a user name
    # or password would otherwise be parsed as URI delimiters and corrupt the connection
    # string. str() keeps an unset (None) value rendering as 'None', exactly as plain
    # f-string interpolation would.
    db_uri = (
        f"mysql+pymysql://{quote_plus(str(db_user))}:{quote_plus(str(db_password))}"
        f"@{db_host}:{db_port}/{db_name}"
    )
    db = SQLDatabase.from_uri(db_uri, sample_rows_in_table_info=3)
    database_schema = db.table_info
    print("Database Schema:\n", database_schema)
    print("---------------------------------------------------\n")
except Exception as e:
    print(f"Error connecting to database or fetching schema: {e}")
    print("Please ensure MySQL is running, your .env variables are correct, and pymysql is installed.")
    # Later cells depend on `db`, so stop here. Re-raising (instead of exit()) halts
    # "Run All" with a traceback while keeping the kernel alive, and gives a non-zero
    # exit status if the notebook is executed as a script.
    raise


--- Connecting to MySQL Database and Fetching Schema ---
Database Schema:
 
CREATE TABLE discounts (
	discount_id INTEGER NOT NULL COMMENT 'Unique ID for each discount entry' AUTO_INCREMENT, 
	t_shirt_id INTEGER NOT NULL COMMENT 'Foreign key referencing the t-shirt being discounted', 
	pct_discount DECIMAL(5, 2) COMMENT 'Discount percentage (0-100)', 
	PRIMARY KEY (discount_id), 
	CONSTRAINT discounts_ibfk_1 FOREIGN KEY(t_shirt_id) REFERENCES t_shirts (t_shirt_id), 
	CONSTRAINT discounts_chk_1 CHECK ((`pct_discount` between 0 and 100))
)COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB COMMENT='Table storing discount information for t-shirts' DEFAULT CHARSET=utf8mb4

/*
3 rows from discounts table:
discount_id	t_shirt_id	pct_discount
1	1	10.00
2	2	15.00
3	3	20.00
*/


CREATE TABLE t_shirts (
	t_shirt_id INTEGER NOT NULL COMMENT 'Unique ID for each t-shirt entry' AUTO_INCREMENT, 
	brand ENUM('Van Huesen','Levi','Nike','Adidas') NOT NULL COMMENT 'Brand of the t-shirt', 
	color ENUM('Red','Blue','

In [7]:
# --- List Available Gemini Models ---
# It's good practice to list models to confirm which ones support 'generateContent'
# in your region and for your API key.
print("Available models supporting 'generateContent':")
available_models = [
    listed_model.name
    for listed_model in genai.list_models()
    if 'generateContent' in listed_model.supported_generation_methods
]
for listed_name in available_models:
    print(f"  - {listed_name}")

if not available_models:
    # Nothing below can work without a usable model. Raising (instead of exit()) stops
    # "Run All" with a clear error and does not shut the kernel down.
    raise RuntimeError(
        "No models found that support 'generateContent'. "
        "Please check your API key and region availability."
    )

Available models supporting 'generateContent':
  - models/gemini-1.0-pro-vision-latest
  - models/gemini-pro-vision
  - models/gemini-1.5-pro-latest
  - models/gemini-1.5-pro-002
  - models/gemini-1.5-pro
  - models/gemini-1.5-flash-latest
  - models/gemini-1.5-flash
  - models/gemini-1.5-flash-002
  - models/gemini-1.5-flash-8b
  - models/gemini-1.5-flash-8b-001
  - models/gemini-1.5-flash-8b-latest
  - models/gemini-2.5-pro-preview-03-25
  - models/gemini-2.5-flash-preview-04-17
  - models/gemini-2.5-flash-preview-05-20
  - models/gemini-2.5-flash
  - models/gemini-2.5-flash-preview-04-17-thinking
  - models/gemini-2.5-flash-lite-preview-06-17
  - models/gemini-2.5-pro-preview-05-06
  - models/gemini-2.5-pro-preview-06-05
  - models/gemini-2.5-pro
  - models/gemini-2.0-flash-exp
  - models/gemini-2.0-flash
  - models/gemini-2.0-flash-001
  - models/gemini-2.0-flash-exp-image-generation
  - models/gemini-2.0-flash-lite-001
  - models/gemini-2.0-flash-lite
  - models/gemini-2.0-flash-p

In [8]:
# --- Initialize Gemini Model ---
# We'll use 'gemini-1.5-flash' as it's often a good balance of speed and capability for free tier.
# You can change this to another model name from the 'Available models' list if needed.
# The name is given without the 'models/' prefix; the verification cell below accepts
# either the bare or the prefixed form when checking it against the list.
model_name_to_use = 'gemini-1.5-flash' 

In [9]:
# Verify the chosen model is in the available list.
# The list stores names as 'models/<name>', so accept both the bare and prefixed form.
accepted_names = {model_name_to_use, f"models/{model_name_to_use}"}
model_is_listed = not accepted_names.isdisjoint(available_models)

if not model_is_listed:
    print(f"\nWarning: The chosen model '{model_name_to_use}' is not in the list of available models.")
    print("Attempting to use it anyway, but it might result in an error.")
    print("Consider changing 'model_name_to_use' to one from the list above.")

# Always create the model object. Skipping this in the warning branch would leave
# `model` undefined and make every later cell fail with a NameError, despite the
# "Attempting to use it anyway" message.
model = genai.GenerativeModel(model_name_to_use)

if model_is_listed:
    print("Model initialised 😄\nGood to go !!!")

Model initialised 😄
Good to go !!!


In [10]:
# --- Example 1: Basic Text Generation with Gemini ---
# This shows how Gemini responds to a general (non-SQL) question.
prompt_general = "Who is the Prime minister of India?"
print(f"\n--- Asking Gemini (General Question): '{prompt_general}' ---")
try:
    response_general = model.generate_content(prompt_general)
    # .text is read inside the try block, so a failure while extracting the text
    # is reported the same way as a failed request.
    print(response_general.text.strip())
except Exception as e:
    print(f"Error during general text generation: {e}")
print("---------------------------------------------------\n")


--- Asking Gemini (General Question): 'Who is the Prime minister of India?' ---
The current Prime Minister of India is Narendra Modi.
---------------------------------------------------



In [11]:
import pymysql

# ================================================================
# --- NEW FUNCTION FOR SQL GENERATION AND EXECUTION ---
# ================================================================

def ask_sql_query(user_question: str, gemini_model, sql_database_obj):
    """
    Generates a MySQL query from a natural language question using Gemini,
    executes it against the provided database, and returns the results.

    Args:
        user_question (str): The natural language question from the user.
        gemini_model: The initialized Gemini GenerativeModel instance.
        sql_database_obj: The initialized langchain.utilities.SQLDatabase object.

    Returns:
        list: A list of tuples containing the query results.
              Returns 0 if it's an aggregate query yielding no results (SUM/COUNT=NULL or empty).
              Returns None if an error occurs.
    """
    print(f"\n--- Processing user question: '{user_question}' ---")

    # Construct the prompt for Gemini
    sql_generation_prompt = f"""
Given the following MySQL database schema:

{sql_database_obj.table_info}

Generate a valid MySQL query for the following question:
"{user_question}"

When generating the query, ensure that:
1. Column names from the schema are used exactly as provided.
2. String comparisons (like 'Nike', 'extra small', 'white') should match the case and format
   of data as it is actually stored in the database. If exact matches are needed,
   use '='. If a case-insensitive match might be better, consider functions like LOWER().
3. For questions asking "how many" or "total", generate a query that returns a single aggregate value (e.g., SUM, COUNT).
   If no matching rows are found, the query should logically yield 0 or NULL for that aggregate.

Provide only the SQL query, without any additional explanation or text.
"""

    # print("--- Prompt sent to Gemini ---")
    # print(sql_generation_prompt)
    # print("---------------------------\n")

    generated_sql = None
    try:
        response_sql = gemini_model.generate_content(sql_generation_prompt)
        generated_sql = response_sql.text.strip()

        # Remove Markdown code block delimiters
        if generated_sql.startswith("```sql"):
            generated_sql = generated_sql[len("```sql"):].strip()
        if generated_sql.endswith("```"):
            generated_sql = generated_sql[:-len("```")].strip()
            
        print("--- Generated SQL Query (cleaned) ---")
        print(generated_sql)
        print("-----------------------------------\n")

    except Exception as e:
        print(f"Error during SQL generation by Gemini: {e}")
        return None # Indicate failure

    # Execute the generated SQL query
    print("--- Executing SQL Query and Fetching Results ---")
    try:
        conn = pymysql.connect(
            host=db_host,
            user=db_user,
            password=db_password,
            database=db_name,
            port=int(db_port)
        )
        cursor = conn.cursor()
        cursor.execute(generated_sql)
        results = cursor.fetchall()
        
        if cursor.description:
            columns = [i[0] for i in cursor.description]
            print("Columns:", columns)
        
        # --- ENHANCED HANDLING FOR EMPTY/NULL/SINGLE AGGREGATE RESULTS ---
        if not results or (len(results) == 1 and results[0][0] is None):
            print("Query Results: 0 (or no matching data)")
            return 0 # Return 0 for aggregate queries with no results
        elif len(results) == 1 and len(results[0]) == 1 and isinstance(results[0][0], (int, float, decimal.Decimal)):
            # If it's a single aggregate value (like COUNT(*), SUM), return the scalar directly
            print(f"Query Result: {results[0][0]}")
            return results[0][0]
        else:
            print("Query Results:")
            for row in results:
                print(row)
            return results # Return the actual results (list of tuples for multi-row/col)

    except pymysql.Error as db_error:
        print(f"Error executing SQL query: {db_error}")
        return None # Indicate failure

    finally:
        if 'conn' in locals() and conn.open: # Ensure connection is closed even on error
            conn.close()

In [12]:
# Core flow of the SpeakSQL project: natural-language question -> Gemini-generated SQL -> MySQL result.
# Test Query 1: a filtered "how many ... left" stock question.
# NOTE(review): in the recorded run Gemini answered this with COUNT(*) (number of matching
# rows), whereas the curated few-shot example for the same question uses
# SUM(stock_quantity) — confirm which one is the intended answer.
q1 = "How many tshirts do we have left for Nike in extra small size and red colour?"
result1 = ask_sql_query(q1, model, db)
print(f"\nFinal Answer for '{q1}': {result1}")

print("\n---------------------------------------------------\n")


--- Processing user question: 'How many tshirts do we have left for Nike in extra small size and red colour?' ---
--- Generated SQL Query (cleaned) ---
SELECT
  COUNT(*)
FROM t_shirts
WHERE
  brand = 'Nike' AND size = 'XS' AND color = 'Red';
-----------------------------------

--- Executing SQL Query and Fetching Results ---
Columns: ['COUNT(*)']
Query Result: 1

Final Answer for 'How many tshirts do we have left for Nike in extra small size and red colour?': 1

---------------------------------------------------



In [18]:
# Test Query 2: revenue from selling all Levi's stock (no discounts).
# NOTE(review): in the recorded run the generated SQL was SUM(price) without multiplying by
# stock_quantity; the curated few-shot example uses SUM(price * stock_quantity) — confirm.
q2 = "If we have to sell all the Levi's tshirt today. How much revenue will our store generate?"
result2 = ask_sql_query(q2, model, db)
print(f"\nFinal Answer for '{q2}': {result2}")


--- Processing user question: 'If we have to sell all the Levi's tshirt today. How much revenue will our store generate?' ---
--- Generated SQL Query (cleaned) ---
SELECT
  SUM(t_shirts.price)
FROM t_shirts
WHERE
  t_shirts.brand = 'Levi';
-----------------------------------

--- Executing SQL Query and Fetching Results ---
Columns: ['SUM(t_shirts.price)']
Query Result: 484

Final Answer for 'If we have to sell all the Levi's tshirt today. How much revenue will our store generate?': 484


In [14]:
# Test Query 3: an unfiltered total across the whole inventory.
q3 = "Show me the total number of products in stock."
result3 = ask_sql_query(q3, model, db)
print(f"\nFinal Answer for '{q3}': {result3}")


--- Processing user question: 'Show me the total number of products in stock.' ---
--- Generated SQL Query (cleaned) ---
SELECT SUM(stock_quantity) FROM t_shirts;
-----------------------------------

--- Executing SQL Query and Fetching Results ---
Columns: ['SUM(stock_quantity)']
Query Result: 3129

Final Answer for 'Show me the total number of products in stock.': 3129


In [15]:
# Test Query 4: inventory value for one size; result4 is reused as a few-shot answer later.
q4 = "How much is the price of inventory for all small sized tshirts?"
result4 = ask_sql_query(q4, model, db)
print(f"\nFinal Answer for '{q4}': {result4}")


--- Processing user question: 'How much is the price of inventory for all small sized tshirts?' ---
--- Generated SQL Query (cleaned) ---
SELECT SUM(price * stock_quantity)
FROM t_shirts
WHERE size = 'S';
-----------------------------------

--- Executing SQL Query and Fetching Results ---
Columns: ['SUM(price * stock_quantity)']
Query Result: 13924

Final Answer for 'How much is the price of inventory for all small sized tshirts?': 13924


In [16]:
# Test Query 5: revenue for Levi's stock after applying discounts (requires a join with `discounts`).
q5 = "If we have to sell all the Levi's tshirt with discounts applied. How much revenue will our store generate? (post discounts)"

# The two commented-out statements below are not executed; presumably they are
# hand-written reference queries for this question, kept for comparison with the generated SQL.
# q5 = """SELECT SUM(t.stock_quantity * IFNULL(t.price * (1 - d.pct_discount / 100), t.price)) as Total_Revenue
#         FROM t_shirts AS t
#         LEFT JOIN discounts AS d
#           ON t.t_shirt_id = d.t_shirt_id
#         WHERE t.brand = 'Levi';"""

# OR 

# q5 = """SELECT SUM(a.Total_Amount * ((100- COALESCE(discounts.pct_discount, 0))/100)) as Total_Revenue
#         FROM (SELECT SUM(price*stock_quantity) as Total_Amount, t_shirt_id from t_shirts
#             	WHERE brand = 'Levi'
#             	GROUP BY t_shirt_id) a
#         LEFT JOIN discounts 
#         ON a.t_shirt_id = discounts.t_shirt_id"""

result5 = ask_sql_query(q5, model, db)
print(f"\nFinal Answer for '{q5}': {result5}")


--- Processing user question: 'If we have to sell all the Levi's tshirt with discounts applied. How much revenue will our store generate? (post discounts)' ---
--- Generated SQL Query (cleaned) ---
SELECT
  SUM(IFNULL((
    t_shirts.price * (
      1 - (
        discounts.pct_discount / 100
      )
    )
  ), t_shirts.price) * t_shirts.stock_quantity)
FROM t_shirts
LEFT JOIN discounts
  ON t_shirts.t_shirt_id = discounts.t_shirt_id
WHERE
  t_shirts.brand = 'Levi';
-----------------------------------

--- Executing SQL Query and Fetching Results ---
Columns: ['SUM(IFNULL((\n    t_shirts.price * (\n      1 - (\n        discounts.pct_discount / 100\n      )\n    )\n  ), t_shirts.price) * t_shirts.stock_quantity)']
Query Result: 30972.450000

Final Answer for 'If we have to sell all the Levi's tshirt with discounts applied. How much revenue will our store generate? (post discounts)': 30972.450000


In [17]:
# Test Query 6: a per-size breakdown, i.e. a multi-row result rather than a single scalar.
# NOTE(review): in the recorded run the generated SQL grouped by size but selected only
# COUNT(*), so the returned rows do not say which size each count belongs to.
q6 = "How many White coloured Levi's tshirts do we have in each size?"
result6 = ask_sql_query(q6, model, db)
print(f"\nFinal Answer for '{q6}': {result6}")


--- Processing user question: 'How many White coloured Levi's tshirts do we have in each size?' ---
--- Generated SQL Query (cleaned) ---
SELECT
  COUNT(*)
FROM t_shirts
WHERE
  brand = 'Levi' AND color = 'White'
GROUP BY
  size;
-----------------------------------

--- Executing SQL Query and Fetching Results ---
Columns: ['COUNT(*)']
Query Results:
(1,)
(1,)
(1,)
(1,)
(1,)

Final Answer for 'How many White coloured Levi's tshirts do we have in each size?': ((1,), (1,), (1,), (1,), (1,))


## Few Shot Learning

In [19]:
# Curated question -> SQL examples used for few-shot prompting.
# Each entry pairs a natural-language question with a hand-written query; "Answer" is
# taken from the result variable produced for the same question earlier in the notebook.
# Every "Answer" key needs a value: a bare `"Answer":` is a SyntaxError.
# NOTE(review): result1, result2 and result6 were produced by Gemini-generated queries
# that, in the recorded run, differ from the curated SQLQuery shown here (COUNT(*) vs
# SUM(stock_quantity), SUM(price) vs SUM(price * stock_quantity)). Confirm the answers
# match the curated queries before using these examples for prompting.
few_shots = [
    {
        "Question":"How many tshirts do we have left for Nike in extra small size and red colour?",
        "SQLQuery":"""SELECT SUM(stock_quantity)
                      FROM t_shirts
                      WHERE brand = 'Nike' AND size = 'XS' AND color = 'Red';""",
        "SQLResult":"Result of the SQL Query",
        "Answer":result1
    },
    {
        "Question":"How much is the price of inventory for all small sized tshirts?",
        "SQLQuery":"""SELECT SUM(price * stock_quantity)
                      FROM t_shirts
                      WHERE size = 'S';""",
        "SQLResult":"Result of the SQL Query",
        "Answer":result4
    },
    {
        "Question":"If we have to sell all the Levi's tshirt with discounts applied. How much revenue will our store generate? (post discounts)",
        "SQLQuery":"""SELECT SUM(a.Total_Amount * ((100- COALESCE(discounts.pct_discount, 0))/100)) as Total_Revenue
                      FROM (SELECT SUM(price*stock_quantity) as Total_Amount, t_shirt_id from t_shirts
                          	WHERE brand = 'Levi'
                          	GROUP BY t_shirt_id) a
                      LEFT JOIN discounts 
                      ON a.t_shirt_id = discounts.t_shirt_id""",
        "SQLResult":"Result of the SQL Query",
        "Answer":result5
    },
    {
        "Question":"If we have to sell all the Levi's tshirt today. How much revenue will our store generate?",
        "SQLQuery":"""SELECT SUM(t_shirts.price * t_shirts.stock_quantity)
                      FROM t_shirts
                      WHERE t_shirts.brand = 'Levi'; """,
        "SQLResult":"Result of the SQL Query",
        "Answer":result2
    },
    {
        "Question":"How many White coloured Levi's tshirts do we have in each size?",
        "SQLQuery":"""SELECT SUM(stock_quantity)
                      FROM t_shirts
                      WHERE brand = 'Levi' AND color = 'White'; """,
        "SQLResult":"Result of the SQL Query",
        "Answer":result6
    }
]