In [36]:
from IPython.display import HTML, display

def set_css(*_event_args):
  """
  Inject a CSS rule that makes ``<pre>`` output blocks wrap long lines.

  Registered below as an IPython ``pre_run_cell`` callback, so the style is
  re-applied before every cell runs. Any positional arguments IPython hands to
  the callback (newer versions pass an execution-info object) are ignored.
  """
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))

_ipython = get_ipython()
# Re-running this cell creates a new `set_css` object; drop callbacks left over
# from earlier runs so the style is not injected once per previous run.
for _callback in list(_ipython.events.callbacks['pre_run_cell']):
  if getattr(_callback, '__name__', None) == 'set_css':
    _ipython.events.unregister('pre_run_cell', _callback)
_ipython.events.register('pre_run_cell', set_css)

In [1]:
!pip3 install streamlit langchain openai langchain-openai langchain-community



In [2]:
import re
import warnings
import logging
import os
from langchain_openai import ChatOpenAI
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langchain_community.tools.sql_database.tool import (
    InfoSQLDatabaseTool,
    ListSQLDatabaseTool,
    QuerySQLDataBaseTool,
)
from langchain.agents import Tool
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.agent_toolkits.sql.base import create_sql_agent
from langchain.agents import Tool as LangchainTool
import sqlite3
import streamlit as st
from datetime import datetime

In [21]:
import warnings
import logging
for name, l in logging.root.manager.loggerDict.items():
    if "streamlit" in name:
        l.disabled = True

# Suppress specific warning about ScriptRunContext
warnings.filterwarnings("ignore", message=".*ScriptRunContext.*")

# Suppress all other warnings
warnings.filterwarnings("ignore")

# Suppress Streamlit internal logs
logging.getLogger("streamlit.runtime.scriptrunner").setLevel(logging.ERROR)

# Suppress all other logs
logging.getLogger().setLevel(logging.ERROR)

In [24]:
import kagglehub

# Kaggle handle of the CoinMarketCap historical-data dataset.
KAGGLE_DATASET = "bizzyvinci/coinmarketcap-historical-data"

# Download the latest version; `path` is the local directory holding the files.
path = kagglehub.dataset_download(KAGGLE_DATASET)
print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/coinmarketcap-historical-data


In [9]:
!ls -ltr /kaggle/input/coinmarketcap-historical-data/coinmarketcap.sqlite
!cp /kaggle/input/coinmarketcap-historical-data/coinmarketcap.sqlite /content/coinmarketcap.sqlite
!ls -ltr /content/coinmarketcap.sqlite

-rw-r--r-- 1 1000 1000 848859136 Jul  6 16:52 /kaggle/input/coinmarketcap-historical-data/coinmarketcap.sqlite


In [25]:
import getpass

# Never hardcode the API key in the notebook: reuse an existing environment
# variable, otherwise prompt for it without echoing it into the output.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")

In [26]:
def _build_strict_sql_prompt(question: str, schema_text: str) -> str:
    """Build the schema-grounded natural-language -> SQL prompt used by `strict_sql_generator`."""
    return (
        "You are a SQL expert.\n"
        "Use only the following database schema to answer the question.\n\n"
        f"{schema_text}\n\n"
        "Translate the question into a valid SQL query.\n"
        "Do NOT make up tables or columns.\n"
        "Return ONLY the SQL query. Do NOT wrap it in markdown or provide explanation.\n\n"
        f"Question: {question}\nSQL:"
    )


def strict_sql_generator(question: str, schema_text: "str | None" = None, model=None) -> str:
    """
    Generate a SQL query from a natural language question using an LLM.

    The prompt embeds the database schema and instructs the model to use only
    the tables and columns in it, and to reply with bare SQL (no explanations,
    comments or markdown). The model may still ignore that instruction, so
    callers should pass the result through ``extract_sql_from_response``.

    Args:
        question (str): The natural language question to translate into SQL.
        schema_text (str, optional): Schema description to ground the model in.
            Defaults to the notebook-level ``schema`` (``db.get_table_info()``).
        model (optional): Chat model exposing ``invoke(prompt)`` whose result has
            a ``.content`` string. Defaults to a fresh ``ChatOpenAI(temperature=0)``.

    Returns:
        str: The model's raw text response (expected to be a SQL query).
    """
    if schema_text is None:
        schema_text = schema  # notebook-level global set from db.get_table_info()
    if model is None:
        model = ChatOpenAI(temperature=0)
    # `invoke(...).content` replaces the deprecated `predict(...)` and returns the same text.
    return model.invoke(_build_strict_sql_prompt(question, schema_text)).content

In [27]:
def fix_invalid_sql(sql: str, error_msg: str, model=None) -> str:
    """
    Ask a language model to repair a SQL query that failed to execute.

    The prompt contains the faulty query and the database error message, and
    asks for only the corrected query with no markdown or explanation. The
    model may still wrap its answer, so callers should run the result through
    ``extract_sql_from_response``.

    Args:
        sql (str): The original, invalid SQL query.
        error_msg (str): The error message produced when the SQL was executed.
        model (optional): Chat model exposing ``invoke(prompt)`` whose result has
            a ``.content`` string. Defaults to the notebook-level ``llm``.

    Returns:
        str: The model's suggested replacement query.
    """
    chat_model = llm if model is None else model
    prompt = (
        f"This SQL caused an error: `{sql}`\n"
        f"Error: {error_msg}\n"
        "Fix the SQL and return only the corrected query. "
        "Do NOT wrap it in markdown or provide explanation."
    )
    # `invoke(...).content` replaces the deprecated `predict(...)`.
    return chat_model.invoke(prompt).content

In [28]:
def extract_sql_from_response(text):
    """
    Extract a raw SQL query from a language-model response.

    If the response contains a markdown code fence (```` ``` ````, ```` ```sql ````,
    ```` ```SQL ```` or ```` ```sqlite ````), the body of the first fence is
    returned. Otherwise the whole text is returned with surrounding whitespace
    stripped.

    Args:
        text (str | None): The model response, possibly containing a fenced SQL block.

    Returns:
        str: The extracted SQL, the stripped text if no fence is found, or ``""``
        when ``text`` is ``None``.
    """
    if text is None:
        return ""
    # The language tag is matched case-insensitively and may be `sqlite`; otherwise
    # a tag such as `SQL` or the `ite` of `sqlite` would leak into the returned query.
    match = re.search(r"```(?:sql(?:ite)?)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()
    return text.strip()

In [29]:
def get_actual_tables(db: "SQLDatabase"):
    """
    Return the lowercase names of all tables and views in a SQLite database.

    Opens a raw DB-API connection from the engine behind the LangChain
    ``SQLDatabase`` and reads ``sqlite_master``. Views are included because a
    query may legitimately select from them. The cursor and connection are
    closed even if the query fails.

    Args:
        db (SQLDatabase): LangChain SQLDatabase connected to a SQLite database.

    Returns:
        set[str]: Table and view names, lowercased (SQLite identifiers are
        case-insensitive, so callers compare lowercase names).
    """
    # NOTE(review): `_engine` is a private attribute of SQLDatabase.
    conn = db._engine.raw_connection()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute("SELECT name FROM sqlite_master WHERE type IN ('table', 'view');")
            return {row[0].lower() for row in cursor.fetchall()}
        finally:
            cursor.close()
    finally:
        conn.close()

In [30]:
def validate_sql_tables(sql_query, db: "SQLDatabase"):
    """
    Report table names referenced by a SQL query that do not exist in the database.

    Table names are taken from ``FROM``/``JOIN`` clauses via
    ``extract_tables_from_sql``. They are lowercased before being compared with
    ``get_actual_tables``, which returns lowercase names, so ``FROM Historical``
    matches an existing ``historical`` table.

    Args:
        sql_query (str): The SQL query to validate.
        db (SQLDatabase): LangChain SQLDatabase connected to the target SQLite database.

    Returns:
        set[str]: Lowercased names referenced in the query but absent from the database.
    """
    actual_tables = get_actual_tables(db)
    # Each regex match is a (from_name, join_name) tuple with '' for the group that did not match.
    referenced_tables = {
        name.lower()
        for match in extract_tables_from_sql(sql_query)
        for name in match
        if name
    }
    return referenced_tables - actual_tables

In [31]:
def extract_tables_from_sql(sql_query):
    """
    Extract the identifiers that follow ``FROM`` or ``JOIN`` in a SQL query.

    Matching is case-insensitive and requires ``FROM``/``JOIN`` to be whole
    words, so a column such as ``valid_from`` is not mistaken for the keyword.

    NOTE: this is a regex heuristic. It also returns CTE names and the schema
    part of ``schema.table``, and it skips quoted identifiers and subqueries.

    Args:
        sql_query (str): A SQL query string.

    Returns:
        list[tuple[str, str]]: One ``(from_name, join_name)`` tuple per match.
        The group that did not match is ``''`` (``re.findall`` semantics).
    """
    return re.findall(r"\bFROM\s+(\w+)|\bJOIN\s+(\w+)", sql_query, re.IGNORECASE)


In [32]:
def auto_correct_sql(sql_query, schema, missing_tables, model=None):
    """
    Ask a language model to rewrite a query that references non-existent tables.

    The prompt contains the faulty query, the missing table names (sorted, so the
    prompt is deterministic) and the valid schema. It asks for a corrected query
    inside a markdown code block.

    Args:
        sql_query (str): The original SQL query with invalid table references.
        schema (str): The schema text of the valid database tables and columns.
        missing_tables (Iterable[str]): Table names that were not found in the database.
        model (optional): Chat model exposing ``invoke(prompt)`` whose result has
            a ``.content`` string. Defaults to the notebook-level ``llm``.

    Returns:
        str: The model's reply. It is expected to be fenced, so pass it through
        ``extract_sql_from_response``.
    """
    chat_model = llm if model is None else model
    correction_prompt = (
        "You have generated the following SQL query, but it uses missing or invalid tables:\n\n"
        f"{sql_query}\n\n"
        f"Missing tables: {', '.join(sorted(missing_tables))}\n\n"
        f"Here is the valid database schema:\n{schema}\n\n"
        "Please return a corrected SQL query using only valid tables and columns. Return only SQL in a markdown code block."
    )
    # `invoke(...).content` replaces the deprecated `predict(...)`.
    return chat_model.invoke(correction_prompt).content

In [39]:
db_path = "/content/coinmarketcap.sqlite"
# Fail early: opening a missing SQLite file silently creates an empty database.
if not os.path.exists(db_path):
    raise FileNotFoundError(f"SQLite database not found at {db_path}; run the copy cell first.")

# Use the same absolute file as the raw sqlite3 connection below. The previous
# relative URI depended on the kernel's working directory.
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
schema = db.get_table_info()
llm = ChatOpenAI(temperature=0, model="gpt-4o")  # or "gpt-4" or "gpt-3.5-turbo"
# NOTE(review): `conn`/`cursor` are not used by the query cell, which opens its own connection.
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
tools = [
    QuerySQLDataBaseTool(db=db),
    InfoSQLDatabaseTool(db=db),
    ListSQLDatabaseTool(db=db),
    LangchainTool(
        name="Explain SQL",
        func=lambda sql: ChatOpenAI(temperature=0).predict(f"Explain this SQL: {sql}"),
        description="Explains the meaning of a given SQL query",
    ),
    LangchainTool(
        name="StrictSQLGenerator",
        func=strict_sql_generator,
        description="NLQ -> SQL only",
    ),
]
strict_sql_tool = Tool(
    name="StrictSQLGenerator",
    func=strict_sql_generator,
    description="Takes a natural language question and returns only the SQL query",
)

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    # `create_sql_agent` adds custom tools through `extra_tools`; it has no `tools` parameter.
    extra_tools=[strict_sql_tool],
    verbose=True,
    # `handle_parsing_errors` is an AgentExecutor option, so it goes through `agent_executor_kwargs`.
    agent_executor_kwargs={"handle_parsing_errors": True},
)

In [37]:
user_question = "What was the highest price of Bitcoin in January 2020?"


def _run_readonly_query(query: str) -> list:
    """
    Execute ``query`` against the SQLite file at ``db_path`` on a read-only connection.

    The query text comes from an LLM, so the database is opened with ``mode=ro``.
    A generated DROP/UPDATE/DELETE then fails instead of modifying the data, and
    a missing file raises instead of creating an empty database. The connection
    is always closed, even when execution fails.

    Args:
        query (str): A single SQL statement.

    Returns:
        list[dict]: One dict per result row, keyed by column name. Empty when the
        statement produces no result columns.

    Raises:
        sqlite3.Error, sqlite3.Warning: Propagated from SQLite for invalid SQL.
    """
    from contextlib import closing
    from pathlib import Path

    readonly_uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    with closing(sqlite3.connect(readonly_uri, uri=True)) as readonly_conn:
        result_cursor = readonly_conn.execute(query)
        result_rows = result_cursor.fetchall()
        if result_cursor.description is None:
            return []
        col_names = [desc[0] for desc in result_cursor.description]
    return [dict(zip(col_names, row)) for row in result_rows]


# Generate the SQL exactly once, so the query that is validated is the query that is executed.
raw_sql = strict_sql_generator(user_question) if user_question else ""
sql = extract_sql_from_response(raw_sql)
generated_sql = sql
st.text_area("Generated SQL", sql)  # renders only under `streamlit run`; echoed below for the notebook
print("Generated SQL:", sql)

if not sql or sql.lower().startswith("-- no valid sql"):
    print("⚠️ The model couldn't generate a SQL query based on your question.")
else:
    missing_tables = validate_sql_tables(sql, db)
    if missing_tables:
        corrected_sql = auto_correct_sql(sql, schema, missing_tables)
        corrected_sql_clean = extract_sql_from_response(corrected_sql)
        print("✅ Corrected SQL:", corrected_sql_clean)
        sql = corrected_sql_clean  # run the corrected query, not the invalid one

    try:
        result = _run_readonly_query(sql)
    # sqlite3.Warning is included because older Pythons raise it for multi-statement input.
    # LLM/network failures in the explanation step below are deliberately not caught here.
    except (sqlite3.Error, sqlite3.Warning) as e:
        print(f"❌ SQL Error: {e}")
        fixed_sql = extract_sql_from_response(fix_invalid_sql(sql, str(e)))
        st.warning("Suggested fix:")
        print("Suggested fix:", fixed_sql)
    else:
        print("✅ Query executed successfully")
        print("Result:", result)
        explain_prompt = f"Explain what the following SQL query does:\n\n{sql}"
        explanation = llm.invoke(explain_prompt).content
        print("Query Explanation", explanation)

✅ Query executed successfully
Result: [{'MAX(high)': 9553.12613251}]
Query Explanation This SQL query is designed to retrieve the highest value from a specific dataset. Here's a breakdown of what each part of the query does:

- `SELECT MAX(high)`: This part of the query is selecting the maximum value from the column named `high`. The `MAX()` function is an aggregate function that returns the largest value in a specified column.

- `FROM historical`: This specifies the table from which to retrieve the data. In this case, the table is named `historical`.

- `WHERE coin_id = 1`: This condition filters the rows to include only those where the `coin_id` column has a value of 1. This implies that the query is interested in data related to a specific coin or asset identified by the ID 1.

- `AND date >= '2020-01-01' AND date <= '2020-01-31'`: These conditions further filter the data to include only the rows where the `date` column falls within the specified range, from January 1, 2020, to Jan