In [None]:
import sqlite3
import ollama
from langchain_community.llms import Ollama
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Local LLM served by Ollama (the 'llama3' model must be pulled and the Ollama server running).
llm = Ollama(model="llama3")

# Plain-text description of the SQLite tables, injected into the prompt below.
db_schema = """
The database has the following tables:
1. 'sweets' table with columns: sweet_name, sweet_type, quantity_in_stock, price, sweetness, freshness.
2. 'special_orders' table with columns: order_id, customer_name, sweet_name, quantity, order_date, pickup_date, instructions.
3. 'discounts' table with columns: discount_id, order_id, discount_type, discount_value, description.
"""


# Prompt that turns a natural-language question into one executable SQLite statement.
# The output rules matter: the model's reply is run as SQL as-is, so any explanatory
# prose around the query makes execution fail.
prompt = PromptTemplate(
    template="""
    You are a helpful assistant that converts a user's question into a single, valid SQLite query,
    using only the tables and columns in the database schema below.

    Database schema:
    {db_schema}

    Rules:
    - Return ONLY the SQL query, ending with a semicolon.
    - Do not wrap the query in markdown code fences and do not add any explanation.
    - Only read data (SELECT); never modify the database.

    User question: {user_question}
    SQL query:
    """,
    input_variables=["db_schema", "user_question"]
)


# prompt -> LLM -> plain string containing the generated SQL.
sql_chain = prompt | llm | StrOutputParser()


def _strip_code_fences(text):
    """Return the SQL inside a markdown code fence, or `text` stripped if there is no fence."""
    import re  # local import keeps this cell self-contained

    if not isinstance(text, str):
        return text
    # LLMs often answer with ```sql\n...\n```; keep only the fenced body.
    match = re.search(r"```[^\n]*\n(.*?)```", text, re.DOTALL)
    return (match.group(1) if match else text).strip()


def execute_sql_query(query, db_path='sweet_shop.db'):
    """
    Execute an LLM-generated SQL query against the sweet-shop SQLite database.

    The query text is produced by a language model from free-form user input,
    so it is treated as untrusted:
      * markdown code fences around the SQL are stripped before execution;
      * the database is opened read-only, so a generated INSERT/UPDATE/DROP
        fails instead of modifying data.

    Args:
        query: SQL text to run (possibly wrapped in a markdown code fence).
        db_path: Path to the SQLite database file. Defaults to 'sweet_shop.db'.

    Returns:
        A list of result-row tuples on success, or a string of the form
        "Error: <message>" if SQLite rejects the query or cannot open the file.
    """
    from pathlib import Path  # local import keeps this cell self-contained

    sql = _strip_code_fences(query)
    # The URI form is required to pass mode=ro; as_uri() needs an absolute path
    # and percent-encodes characters such as '?' or '#' in the file name.
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    conn = None
    try:
        conn = sqlite3.connect(uri, uri=True)
        return conn.execute(sql).fetchall()
    except (sqlite3.Error, sqlite3.Warning) as e:
        # sqlite3.Warning: "You can only execute one statement at a time" on Python < 3.12.
        return f"Error: {e}"
    finally:
        if conn is not None:
            conn.close()


In [None]:
# First, make sure you have the required libraries installed:
%pip install ollama langchain_community langchain_core

Collecting ollama
  Downloading ollama-0.5.3-py3-none-any.whl.metadata (4.3 kB)
Collecting langchain_community
  Downloading langchain_community-0.3.29-py3-none-any.whl.metadata (2.9 kB)
Collecting langchain_core
  Downloading langchain_core-0.3.75-py3-none-any.whl.metadata (5.7 kB)
Collecting requests<3,>=2.32.5 (from langchain_community)
  Downloading requests-2.32.5-py3-none-any.whl.metadata (4.9 kB)
Collecting dataclasses-json<0.7,>=0.6.7 (from langchain_community)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.7,>=0.6.7->langchain_community)
  Downloading marshmallow-3.26.1-py3-none-any.whl.metadata (7.3 kB)
Collecting typing-inspect<1,>=0.4.0 (from dataclasses-json<0.7,>=0.6.7->langchain_community)
  Downloading typing_inspect-0.9.0-py3-none-any.whl.metadata (1.5 kB)
Collecting mypy-extensions>=0.3.0 (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7,>=0.6.7->langchain_community)
  Downloadin

In [None]:
# --- 1. Import Necessary Libraries ---
import sqlite3
import ollama
from langchain_community.llms import Ollama
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

# --- 2. Database Setup Function ---
# This function creates and populates the SQLite database file.
def setup_database(db_path='sweet_shop.db'):
    """
    Create the sweet-shop schema in a SQLite file and seed it with sample data.

    Safe to call repeatedly: tables use CREATE TABLE IF NOT EXISTS, and every
    seed row carries an explicit primary key, so INSERT OR IGNORE skips rows
    that already exist instead of inserting duplicates on each run.

    Args:
        db_path: Path to the SQLite database file (created if missing).
            Defaults to 'sweet_shop.db'.
    """
    conn = sqlite3.connect(db_path)
    try:
        # order_id / discount_id are given explicitly: with an auto-assigned
        # INTEGER PRIMARY KEY, INSERT OR IGNORE would never hit a conflict and
        # every call would append another copy of the seed orders/discounts.
        conn.executescript("""
            CREATE TABLE IF NOT EXISTS sweets (
                sweet_name TEXT PRIMARY KEY,
                sweet_type TEXT NOT NULL,
                quantity_in_stock INTEGER NOT NULL,
                price REAL NOT NULL,
                sweetness REAL NOT NULL,
                freshness REAL NOT NULL
            );

            INSERT OR IGNORE INTO sweets (sweet_name, sweet_type, quantity_in_stock, price, sweetness, freshness) VALUES
            ('Chocolate Lava Cupcake', 'Cupcake', 50, 3.50, 8.5, 9.2),
            ('Oatmeal Raisin Cookie', 'Cookie', 75, 2.00, 6.0, 8.8),
            ('Fudgy Walnut Brownie', 'Brownie', 40, 4.00, 9.0, 9.5),
            ('Red Velvet Cake Pop', 'Cake Pop', 60, 2.50, 7.5, 9.0),
            ('Pistachio Macaron', 'Macaron', 80, 1.75, 7.0, 9.8);

            CREATE TABLE IF NOT EXISTS special_orders (
                order_id INTEGER PRIMARY KEY,
                customer_name TEXT NOT NULL,
                sweet_name TEXT,
                quantity INTEGER NOT NULL,
                order_date TEXT NOT NULL,
                pickup_date TEXT NOT NULL,
                instructions TEXT,
                FOREIGN KEY (sweet_name) REFERENCES sweets(sweet_name)
            );

            INSERT OR IGNORE INTO special_orders (order_id, customer_name, sweet_name, quantity, order_date, pickup_date, instructions) VALUES
            (1, 'Jane Doe', 'Chocolate Lava Cupcake', 24, '2024-10-26', '2024-10-29', 'Include a "Happy Birthday" note.'),
            (2, 'John Smith', 'Pistachio Macaron', 50, '2024-10-27', '2024-10-30', 'Tie with a ribbon, no box.'),
            (3, 'Emily White', NULL, 1, '2024-10-28', '2024-10-31', 'Custom cake: 3-tier vanilla, frosting with blue flowers.');

            CREATE TABLE IF NOT EXISTS discounts (
                discount_id INTEGER PRIMARY KEY,
                order_id INTEGER NOT NULL,
                discount_type TEXT NOT NULL,
                discount_value REAL NOT NULL,
                description TEXT,
                FOREIGN KEY (order_id) REFERENCES special_orders(order_id)
            );

            INSERT OR IGNORE INTO discounts (discount_id, order_id, discount_type, discount_value, description) VALUES
            (1, 1, 'Percentage', 10.00, 'Loyalty discount for a large order.'),
            (2, 2, 'Fixed Amount', 5.00, 'Promotional discount.');
        """)
        conn.commit()
    finally:
        # Close even if the script fails part-way.
        conn.close()

# --- 3. LLM Setup and Prompt Template ---
# Initialize the local LLM. Make sure 'llama3' is pulled in Ollama.
llm = Ollama(model="llama3")

# Define the database schema information for the LLM
db_schema = """
The database has the following tables:
1. 'sweets' table with columns: sweet_name, sweet_type, quantity_in_stock, price, sweetness, freshness.
2. 'special_orders' table with columns: order_id, customer_name, sweet_name, quantity, order_date, pickup_date, instructions.
3. 'discounts' table with columns: discount_id, order_id, discount_type, discount_value, description.
"""

# Create a prompt template for the LLM to guide its response.
# The output rules matter: the model's reply is passed to execute_sql_query as-is,
# so explanatory prose around the SQL would make execution fail.
prompt = PromptTemplate(
    template="""
    You are a helpful assistant that converts a user's question into a single, valid SQLite query,
    using only the tables and columns in the database schema below.

    Database schema:
    {db_schema}

    Rules:
    - Return ONLY the SQL query, ending with a semicolon.
    - Do not wrap the query in markdown code fences and do not add any explanation.
    - Only read data (SELECT); never modify the database.

    User question: {user_question}
    SQL query:
    """,
    input_variables=["db_schema", "user_question"]
)

# Create a processing chain with LangChain: prompt -> LLM -> plain string of SQL
sql_chain = prompt | llm | StrOutputParser()

# --- 4. SQL Query Execution Function ---
# This function runs the query generated by the LLM against the database.
def _strip_code_fences(text):
    """Return the SQL inside a markdown code fence, or `text` stripped if there is no fence."""
    import re  # local import keeps this cell self-contained

    if not isinstance(text, str):
        return text
    # LLMs often answer with ```sql\n...\n```; keep only the fenced body.
    match = re.search(r"```[^\n]*\n(.*?)```", text, re.DOTALL)
    return (match.group(1) if match else text).strip()


def execute_sql_query(query, db_path='sweet_shop.db'):
    """
    Execute an LLM-generated SQL query against the sweet-shop SQLite database.

    The query text is produced by a language model from free-form user input,
    so it is treated as untrusted:
      * markdown code fences around the SQL are stripped before execution;
      * the database is opened read-only, so a generated INSERT/UPDATE/DROP
        fails instead of modifying data.

    Args:
        query: SQL text to run (possibly wrapped in a markdown code fence).
        db_path: Path to the SQLite database file. Defaults to 'sweet_shop.db'.

    Returns:
        A list of result-row tuples on success, or a string of the form
        "Error: <message>" if SQLite rejects the query or cannot open the file.
    """
    from pathlib import Path  # local import keeps this cell self-contained

    sql = _strip_code_fences(query)
    # The URI form is required to pass mode=ro; as_uri() needs an absolute path
    # and percent-encodes characters such as '?' or '#' in the file name.
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    conn = None
    try:
        conn = sqlite3.connect(uri, uri=True)
        return conn.execute(sql).fetchall()
    except (sqlite3.Error, sqlite3.Warning) as e:
        # sqlite3.Warning: "You can only execute one statement at a time" on Python < 3.12.
        return f"Error: {e}"
    finally:
        if conn is not None:
            conn.close()

# --- 5. Main Chatbot Logic ---
def _answer_question(question):
    """Generate SQL for one question, run it, and print every step to stdout."""
    print("\n" + "=" * 50)
    print(f"User: {question}")
    print("Generating SQL query with LLM...")
    try:
        sql_text = sql_chain.invoke({"db_schema": db_schema, "user_question": question})
        print(f"Generated SQL: {sql_text}")
        rows = execute_sql_query(sql_text)
        print(f"Chatbot: Query Result: {rows}")
    except Exception as exc:
        # Best-effort demo loop: report the failure (e.g. the Ollama server
        # refusing connections) and let the caller move on to the next question.
        print(f"An error occurred: {exc}")


def main():
    """
    Run the chatbot demo end to end.

    Builds/seeds the SQLite database first, then sends each sample question
    through the LLM chain and executes the SQL it produces, printing the
    question, the generated SQL and the query result (or the error).
    """
    # Tables and seed rows must exist before any generated SQL is executed.
    setup_database()

    sample_questions = (
        "What is the price of Red Velvet Cake Pop?",
        "How many Oatmeal Raisin Cookies are in stock?",
        "What are the names of all the sweets?",
        "What are the details of all special orders?",
    )
    for question in sample_questions:
        _answer_question(question)


if __name__ == "__main__":
    main()


User: What is the price of Red Velvet Cake Pop?
Generating SQL query with LLM...
An error occurred: HTTPConnectionPool(host='localhost', port=11434): Max retries exceeded with url: /api/generate (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7bb6e05f5fa0>: Failed to establish a new connection: [Errno 111] Connection refused'))

User: How many Oatmeal Raisin Cookies are in stock?
Generating SQL query with LLM...
An error occurred: HTTPConnectionPool(host='localhost', port=11434): Max retries exceeded with url: /api/generate (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7bb6e05f6930>: Failed to establish a new connection: [Errno 111] Connection refused'))

User: What are the names of all the sweets?
Generating SQL query with LLM...
An error occurred: HTTPConnectionPool(host='localhost', port=11434): Max retries exceeded with url: /api/generate (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7bb6