<a href="https://colab.research.google.com/github/simarjot16/Text-To-SQL-PoC/blob/main/Text_To_SQL_PoC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install dependencies
!pip install streamlit langchain transformers sentence-transformers spacy faiss-cpu pandas pyngrok

# Download spaCy model
!python -m spacy download en_core_web_sm

# Install ngrok
!wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
!unzip ngrok-stable-linux-amd64.zip
!mv ngrok /usr/local/bin/

# Authenticate ngrok
!ngrok authtoken 2vfUZjVXAmt6U6xZih8ieJXbRPM_6jDVBJZckYHdGyVDmHCmH

Collecting streamlit
  Downloading streamlit-1.44.1-py3-none-any.whl.metadata (8.9 kB)
Collecting faiss-cpu
  Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.4 kB)
Collecting pyngrok
  Downloading pyngrok-7.2.4-py3-none-any.whl.metadata (8.7 kB)
Collecting watchdog<7,>=2.1.5 (from streamlit)
  Downloading watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
Collecting pydeck<1,>=0.8.0b4 (from streamlit)
  Downloading pydeck-0.9.1-py2.py3-none-any.whl.metadata (4.1 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metad

In [None]:
from google.colab import files


def upload_as(target_name):
    """Prompt for one file upload and save its bytes locally as ``target_name``.

    ``files.upload()`` returns a dict ``{uploaded_filename: bytes}``. Indexing it
    by the expected name raised KeyError whenever the picked file had a different
    name. Here a file with the exact name is preferred, and a single upload under
    any other name is also accepted.

    Returns:
        The path written (``target_name``).
    Raises:
        ValueError: if the upload did not contain exactly one usable file.
    """
    print(f"Select the file to save as {target_name}")
    uploaded = files.upload()
    if target_name in uploaded:
        content = uploaded[target_name]
    elif len(uploaded) == 1:
        content = next(iter(uploaded.values()))
    else:
        raise ValueError(
            f"Expected one file for {target_name!r}, got: {sorted(uploaded)}"
        )
    with open(target_name, "wb") as f:
        f.write(content)
    return target_name


# Upload schema.csv and example_queries.json (names expected by app.py)
upload_as("schema.csv")
upload_as("example_queries.json")

KeyboardInterrupt: 

In [None]:
%%writefile app.py
import json
import re
import time

# Third-party imports are kept in their original relative order (the import order
# of faiss relative to torch-based libraries has been reported to matter on some platforms).
import streamlit as st
import pandas as pd
import spacy
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from transformers import T5Tokenizer, T5ForConditionalGeneration
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

# Streamlit re-executes this whole script on every widget interaction. The loaders
# below are wrapped in st.cache_resource so the models (and the schema index) are
# built once per server process instead of on every click.


@st.cache_resource
def _load_spacy():
    """Load the spaCy English pipeline (used only for tokenisation in this app)."""
    return spacy.load("en_core_web_sm")


@st.cache_resource
def _load_embedder():
    """Load the Sentence-BERT model used for schema mapping."""
    return SentenceTransformer("all-MiniLM-L6-v2")


@st.cache_resource
def _build_schema_index(elements):
    """Embed the schema strings and put them in an exact (flat) L2 FAISS index.

    ``elements`` is a tuple so Streamlit can hash it as the cache key; a changed
    schema therefore produces a new index. Returns ``(embeddings, index)``.
    """
    embeddings = _load_embedder().encode(list(elements))
    flat_index = faiss.IndexFlatL2(embeddings.shape[1])
    flat_index.add(np.array(embeddings, dtype=np.float32))
    return embeddings, flat_index


@st.cache_resource
def _load_t5():
    """Load the T5 tokenizer and model used for query generation."""
    return (
        T5Tokenizer.from_pretrained("t5-small"),
        T5ForConditionalGeneration.from_pretrained("t5-small"),
    )


# spaCy: tokenises the user query in detect_intent_and_entities
nlp = _load_spacy()

# Load schema from CSV (columns used: table_name, column_name, description)
# into {table: {column: description}}.
schema_df = pd.read_csv("schema.csv")
schema_dict = {}
for table, column, description in schema_df[
    ["table_name", "column_name", "description"]
].itertuples(index=False):
    schema_dict.setdefault(table, {})[column] = description

# Load example queries from JSON (the UI reads q["query"] from each entry)
with open("example_queries.json", "r") as f:
    example_queries = json.load(f)

# Initialize Sentence-BERT for embeddings (Schema Mapping)
embedder = _load_embedder()

# One entry per table and per "table.column"; schema_labels[i] describes schema_elements[i]
schema_elements = []
schema_labels = []
for table, columns in schema_dict.items():
    schema_elements.append(table)
    schema_labels.append({"type": "table", "value": table})
    for column in columns:
        qualified = f"{table}.{column}"
        schema_elements.append(qualified)
        schema_labels.append({"type": "column", "value": qualified})

# FAISS index for schema mapping; search() on it returns squared L2 distances
schema_embeddings, index = _build_schema_index(tuple(schema_elements))
dimension = schema_embeddings.shape[1]

# Initialize T5 for query generation
tokenizer, model = _load_t5()

# Simplified Intent Detection (using spaCy and rules)
def detect_intent_and_entities(query):
    """Extract table/column mentions and a raw WHERE condition from ``query``.

    Returns ``(intent, entities, clarification)``:
      * ``intent`` is always ``"SELECT"`` in this PoC.
      * ``entities`` has ``"tables"``, ``"columns"`` (as ``"table.column"``) and
        ``"conditions"`` (raw text after the first standalone "where").
      * ``clarification`` is a follow-up question when the query is too vague,
        otherwise ``None``.

    Keyword rules use whole words, so "somewhere", "laptop" or "nearby" no longer
    trigger the condition / "top ... by" rules the way substring checks did.
    """
    doc = nlp(query.lower())
    words = {token.text for token in doc}
    intent = "SELECT"  # Simplified: assume SELECT for PoC
    entities = {"tables": [], "columns": [], "conditions": []}

    # Exact token match against table and column names. Tokens come from the
    # lower-cased query, so only lower-case schema names can match.
    for token in doc:
        if token.text in schema_dict:
            entities["tables"].append(token.text)
        for table, columns in schema_dict.items():
            if token.text in columns:
                entities["columns"].append(f"{table}.{token.text}")

    # Everything after the first standalone "where" is the condition. Slice the
    # ORIGINAL query so string literals keep their case.
    where_match = re.search(r"\bwhere\b", query, flags=re.IGNORECASE)
    if where_match:
        condition_part = query[where_match.end():].strip()
        if condition_part:
            entities["conditions"].append(condition_part)

    # Check for vagueness
    if not entities["tables"]:
        return intent, entities, "Which table would you like to query? For example, subscription, transaction, media, email, or shop?"
    if "top" in words and "by" not in words:
        return intent, entities, "Can you define 'top' by subscription spends or shop spends or something else?"

    return intent, entities, None

# Schema Mapping (using FAISS and Sentence-BERT)
def _nearest_of_type(names, kind, max_distance):
    """Map each name to its closest schema element of type ``kind`` ("table"/"column").

    Names with no same-type element closer than ``max_distance`` (squared L2)
    are dropped.
    """
    if not names or index.ntotal == 0:
        return []
    query_vectors = np.asarray(embedder.encode(list(names)), dtype=np.float32)
    # Rank against the whole flat index so wrong-type neighbours can be skipped.
    distances, ids = index.search(query_vectors, index.ntotal)
    matches = []
    for row_distances, row_ids in zip(distances, ids):
        for distance, element_id in zip(row_distances, row_ids):
            # Results are in ascending distance order: the first hit at or over
            # the threshold means nothing closer remains.
            if element_id < 0 or distance >= max_distance:
                break
            label = schema_labels[element_id]
            if label["type"] == kind:
                matches.append(label["value"])
                break
    return matches


def map_to_schema(entities, max_distance=0.5):
    """Map extracted table/column names onto schema elements via the FAISS index.

    A table mention can only map to a table and a column mention only to a
    column. Previously the single nearest neighbour of any type was accepted,
    which could place a column into the tables list. ``max_distance`` is a
    squared-L2 *distance* bound (smaller is closer), not a similarity.

    Returns a dict with ``"tables"``, ``"columns"`` and ``"conditions"`` lists.
    Conditions are passed through unchanged, as a copy.
    """
    return {
        "tables": _nearest_of_type(entities["tables"], "table", max_distance),
        "columns": _nearest_of_type(entities["columns"], "column", max_distance),
        # Copy so the input and output dicts don't share a mutable list.
        "conditions": list(entities["conditions"]),
    }

# Query Generation (using T5)
def generate_sql(intent, mapped_entities):
    """Generate SQL for the mapped entities with T5, falling back to a template.

    The entities are rendered into a text prompt for ``t5-small``. The model
    output is kept only if it starts with ``SELECT`` (case-insensitive).
    Otherwise ``SELECT <all columns | *> FROM <first table> [WHERE c1 AND c2 ...];``
    is built directly from the entities.

    Raises:
        ValueError: if ``mapped_entities["tables"]`` is empty (no FROM target).
    """
    tables = mapped_entities["tables"]
    columns = mapped_entities["columns"]
    conditions = mapped_entities["conditions"]
    if not tables:
        raise ValueError("Cannot generate SQL: no table was mapped from the query.")

    column_list = ", ".join(columns) if columns else "*"
    where_clause = ("WHERE " + " AND ".join(conditions)) if conditions else ""

    prompt_template = PromptTemplate(
        input_variables=["intent", "tables", "columns", "conditions"],
        template="Generate an SQL query with the following: intent={intent}, tables={tables}, columns={columns}, conditions={conditions}"
    )
    input_text = prompt_template.format(
        intent=intent,
        tables=", ".join(tables),
        columns=column_list,
        conditions=where_clause,
    )

    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
    # Unpack so the attention mask is passed along with input_ids.
    outputs = model.generate(**inputs, max_length=100)
    generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # The prompt itself contains "SELECT" (intent=SELECT), so a merely-echoing
    # output would pass a substring check. Require the output to START with SELECT.
    if not re.match(r"\s*select\b", generated_sql, flags=re.IGNORECASE):
        generated_sql = f"SELECT {column_list} FROM {tables[0]}"
        if where_clause:
            generated_sql += f" {where_clause}"
        generated_sql += ";"

    return generated_sql

# Main processing function
def process_query(query):
    """Run the natural-language -> SQL pipeline for one query.

    Returns ``(generated_sql, clarification_prompt, results)``. On success
    ``clarification_prompt`` is None and ``results`` is a simulated DataFrame.
    Otherwise the first and last items are None and ``clarification_prompt``
    tells the user what to fix.
    """
    # Step 1: Intent Detection
    intent, entities, clarification_prompt = detect_intent_and_entities(query)
    if clarification_prompt:
        return None, clarification_prompt, None

    # Step 2: Schema Mapping
    mapped_entities = map_to_schema(entities)
    if not mapped_entities["tables"]:
        # Mapping can drop every table (distance threshold); generation needs one.
        return None, f"Could not match a table in your query to the schema. Available tables: {list(schema_dict.keys())}", None

    # Step 3: Validation, done before generation so invalid input skips the T5 call.
    for table in mapped_entities["tables"]:
        if table not in schema_dict:
            return None, f"Table '{table}' does not exist. Available tables: {list(schema_dict.keys())}", None
    # Compare whole "table.column" strings so dotted names can't break unpacking.
    known_columns = {f"{t}.{c}" for t, cols in schema_dict.items() for c in cols}
    for column in mapped_entities["columns"]:
        if column not in known_columns:
            table, _, col = column.rpartition(".")
            return None, f"Column '{col}' does not exist in table '{table}'.", None

    # Step 4: Query Generation
    generated_sql = generate_sql(intent, mapped_entities)

    # Simulate results (since we don't have Redshift access in this PoC)
    results = pd.DataFrame({"Placeholder": ["Result 1", "Result 2"]})
    return generated_sql, None, results

# Streamlit UI
# Streamlit re-runs this whole script on every widget interaction, and
# st.button() is True only during the one rerun triggered by its click.
# Anything that must outlive that rerun (query text, last result, latency) is
# kept in st.session_state instead of local variables.


def _use_example():
    """Callback: copy the selected example query into the main query box."""
    selected = st.session_state.get("example_query")
    if selected:
        st.session_state["query_input"] = selected


def _run_query(text):
    """Process ``text`` and remember its outcome and latency in session state."""
    started = time.perf_counter()
    generated_sql, clarification_prompt, results = process_query(text)
    st.session_state["last_latency"] = time.perf_counter() - started
    st.session_state["last_result"] = {
        "sql": generated_sql,
        "clarification": clarification_prompt,
        "results": results,
    }


def _submit_rephrased():
    """Callback: process the rephrased query once the user submits it."""
    rephrased = st.session_state.get("rephrased_query", "")
    if rephrased:
        _run_query(rephrased)


st.title("Text-to-SQL Generative AI Tool (PoC)")

# Display schema
st.subheader("Schema Overview")
st.dataframe(schema_df)

# Query Input (keyed so the example callback can fill it in)
query = st.text_input("Enter your query:", placeholder="e.g., Show me the fans with an active subscription in January 2023", key="query_input")

# Example Queries Dropdown
example_query = st.selectbox("Or select an example query:", [q["query"] for q in example_queries], key="example_query")
st.button("Use Example Query", on_click=_use_example)

# Submit Button
if st.button("Submit"):
    if query:
        with st.spinner("Processing your query..."):
            _run_query(query)
    else:
        st.session_state.pop("last_result", None)
        st.error("Please enter a query.")

# Render the latest outcome outside the button branch so it survives later
# interactions (rephrasing, rating, submitting feedback).
last_result = st.session_state.get("last_result")
if last_result:
    # Handle Clarification Prompt
    if last_result["clarification"]:
        st.warning(last_result["clarification"])
        st.text_input("Please rephrase your query:", placeholder="e.g., Show me the top fans by number of shop purchase made", key="rephrased_query", on_change=_submit_rephrased)

    # Display Generated SQL
    if last_result["sql"]:
        st.subheader("Generated SQL:")
        st.code(last_result["sql"], language="sql")

    # Display Results
    if last_result["results"] is not None:
        st.subheader("Query Results (Simulated):")
        st.dataframe(last_result["results"])

    # User Feedback
    st.subheader("Rate your experience:")
    rating = st.slider("Satisfaction (1 = Poor, 5 = Excellent)", 1, 5, 3)
    if st.button("Submit Feedback"):
        st.success(f"Thank you for your feedback! You rated: {rating}/5")

# Sidebar for Metrics: show measured latency instead of hard-coded numbers
with st.sidebar:
    st.header("System Metrics")
    last_latency = st.session_state.get("last_latency")
    if last_latency is None:
        st.write("Query Latency: n/a (no query processed yet)")
    else:
        st.write(f"Query Latency: {last_latency:.2f} seconds")
    st.write("SQL Generation Accuracy: not measured (placeholder)")
    st.write("SQL Generation Accuracy: 85%")

Overwriting app.py


In [None]:
import time
import requests
import subprocess

# Kill any existing ngrok and Streamlit processes to avoid conflicts
!pkill ngrok
!pkill streamlit

# Run Streamlit in the background
get_ipython().system_raw("streamlit run app.py &")

# Run ngrok in the background to expose port 8501
get_ipython().system_raw("/usr/local/bin/ngrok http 8501 &")

# Wait for ngrok to initialize (increased delay)
time.sleep(10)

# Retry fetching the ngrok public URL with multiple attempts
max_attempts = 5
attempt = 1
public_url = None

while attempt <= max_attempts and not public_url:
    print(f"Attempt {attempt}/{max_attempts}: Fetching ngrok URL...")
    try:
        # Make the request to ngrok's API
        response = requests.get("http://localhost:4040/api/tunnels", timeout=10)
        response.raise_for_status()  # Raise an error for bad HTTP status codes

        # Parse the JSON response
        data = response.json()
        public_url = data['tunnels'][0]['public_url']
        print(f"Streamlit app is running at: {public_url}")

    except requests.exceptions.RequestException as e:
        print(f"Error fetching ngrok URL: {e}")
        print("Please check if ngrok is running and accessible at http://localhost:4040")
        attempt += 1
        time.sleep(5)  # Wait 5 seconds before retrying
    except (KeyError, IndexError) as e:
        print(f"Error parsing ngrok response: {e}")
        print("ngrok might not have initialized the tunnel yet.")
        attempt += 1
        time.sleep(5)
    except Exception as e:
        print(f"Unexpected error: {e}")
        attempt += 1
        time.sleep(5)

if not public_url:
    print("Failed to fetch ngrok URL after maximum attempts.")
    print("Please check the following:")
    print("1. Ensure your ngrok authtoken is correct.")
    print("2. Verify that ngrok and Streamlit processes are running.")
    print("3. Check for network issues in Colab.")
    print("You can also try running this app locally or using Streamlit Community Cloud.")

Attempt 1/5: Fetching ngrok URL...
Error fetching ngrok URL: HTTPConnectionPool(host='localhost', port=4040): Max retries exceeded with url: /api/tunnels (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7974ce514750>: Failed to establish a new connection: [Errno 111] Connection refused'))
Please check if ngrok is running and accessible at http://localhost:4040
Attempt 2/5: Fetching ngrok URL...
Error fetching ngrok URL: HTTPConnectionPool(host='localhost', port=4040): Max retries exceeded with url: /api/tunnels (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7974ce5167d0>: Failed to establish a new connection: [Errno 111] Connection refused'))
Please check if ngrok is running and accessible at http://localhost:4040
Attempt 3/5: Fetching ngrok URL...
Error fetching ngrok URL: HTTPConnectionPool(host='localhost', port=4040): Max retries exceeded with url: /api/tunnels (Caused by NewConnectionError('<urllib3.connection.HTTPConnect