Module 1: Cloud Data Source Connector Framework
Focus:

Connect to the sample database (this notebook uses the Chinook SQLite file; the original target was AdventureWorks on PostgreSQL or SQL Server)

Map schema into a clean JSON abstraction

Extract basic metadata: tables, columns, relationships

In [None]:
!pip install -q sentence-transformers faiss-cpu


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.3/31.3 MB[0m [31m62.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m61.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m34.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m40.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
!pip install -q langchain


In [None]:
!pip install -U langchain-community

Collecting langchain-community
  Downloading langchain_community-0.3.23-py3-none-any.whl.metadata (2.5 kB)
Collecting dataclasses-json<0.7,>=0.5.7 (from langchain-community)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting pydantic-settings<3.0.0,>=2.4.0 (from langchain-community)
  Downloading pydantic_settings-2.9.1-py3-none-any.whl.metadata (3.8 kB)
Collecting httpx-sse<1.0.0,>=0.4.0 (from langchain-community)
  Downloading httpx_sse-0.4.0-py3-none-any.whl.metadata (9.0 kB)
Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.7,>=0.5.7->langchain-community)
  Downloading marshmallow-3.26.1-py3-none-any.whl.metadata (7.3 kB)
Collecting typing-inspect<1,>=0.4.0 (from dataclasses-json<0.7,>=0.5.7->langchain-community)
  Downloading typing_inspect-0.9.0-py3-none-any.whl.metadata (1.5 kB)
Collecting python-dotenv>=0.21.0 (from pydantic-settings<3.0.0,>=2.4.0->langchain-community)
  Downloading python_dotenv-1.1.0-py3-none-any.whl.metadata (24 kB

In [None]:
!pip install groq

Collecting groq
  Downloading groq-0.24.0-py3-none-any.whl.metadata (15 kB)
Downloading groq-0.24.0-py3-none-any.whl (127 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.5/127.5 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: groq
Successfully installed groq-0.24.0


In [None]:
# Natural-language business question; the retrieval cell later uses it as the search query.
biz_question = "Which country has the most customers?"

In [None]:
import sqlite3
import json
import re
import os



# STEP 2: Point to the database file path
db_path = "/content/Chinook_Sqlite.sqlite"
# sqlite3.connect() creates a new empty database when the file is missing, so
# check first. An explicit raise (unlike `assert`) is not stripped by `python -O`.
if not os.path.exists(db_path):
    raise FileNotFoundError("Database file not found at /content")

# STEP 3: Connect and process
conn = sqlite3.connect(db_path)
cursor = conn.cursor()

# Get all user tables. Names starting with "sqlite_" are reserved for SQLite's
# own bookkeeping tables (e.g. sqlite_sequence) and are not part of the schema.
cursor.execute(
    "SELECT name FROM sqlite_master "
    "WHERE type='table' AND name NOT LIKE 'sqlite_%';"
)
tables = [row[0] for row in cursor.fetchall()]

def beautify_name(name):
    """Turn a schema identifier into a title-cased, human-readable label.

    Underscores become spaces, a literal "ID" is normalised to "Id", and
    CamelCase boundaries are split ("MediaTypeId" -> "Media Type Id").

    A space is only inserted at a real word boundary (lower/digit -> upper, or
    the end of an acronym run), so names that mix underscores and capitals
    such as "Customer_ID" no longer come out with a double space, and
    all-caps names are not exploded letter by letter.

    Args:
        name: Raw table or column name.

    Returns:
        The label with single spaces between words, title-cased.
    """
    text = name.replace("_", " ").replace("ID", "Id")
    # First alternative: "albumId" -> "album Id".
    # Second alternative: "HTTPServer" -> "HTTP Server".
    text = re.sub(r'(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])', ' ', text)
    # Collapse any run of whitespace left over from the original identifier.
    return " ".join(text.split()).title()

def generate_column_description(name):
    """Build the generic one-line description for a column.

    The column name is converted to a readable label with ``beautify_name``
    and placed in front of the fixed suffix " of the record.".
    """
    label = beautify_name(name)
    return f"{label} of the record."

# One dict per table: names, a generic description, the row count, the column
# list and the outgoing foreign keys.
enriched_schema = []

try:
    for table in tables:
        # Quote the name as an SQL identifier (double quotes, embedded quotes
        # doubled) so unusual table names cannot break the statements below.
        quoted_table = '"' + table.replace('"', '""') + '"'

        # Columns: each row is (cid, name, type, notnull, dflt_value, pk)
        cursor.execute(f"PRAGMA table_info({quoted_table})")
        cols = cursor.fetchall()

        # Foreign keys: each row is (id, seq, table, from, to, on_update, on_delete, match)
        cursor.execute(f"PRAGMA foreign_key_list({quoted_table})")
        fks = cursor.fetchall()

        # Row count is best-effort: a table that cannot be counted is recorded
        # as None, but only database errors are tolerated (not Ctrl-C, typos, ...).
        try:
            cursor.execute(f"SELECT COUNT(*) FROM {quoted_table}")
            row_count = cursor.fetchone()[0]
        except sqlite3.Error:
            row_count = None

        enriched_schema.append({
            "table_name": table,
            "business_name": beautify_name(table),
            "description": f"Contains information related to {beautify_name(table)}.",
            "row_count": row_count,
            "columns": [{
                "column_name": col[1],
                "business_name": beautify_name(col[1]),
                "data_type": col[2],
                # pk is 0 for non-key columns and the 1-based key position otherwise
                "is_primary_key": bool(col[5]),
                "description": generate_column_description(col[1])
            } for col in cols],
            "foreign_keys": [{
                "from_column": fk[3],
                "to_table": fk[2],
                "to_column": fk[4]
            } for fk in fks]
        })
finally:
    # Release the database handle even if introspection fails part-way through.
    conn.close()

# STEP 4: Save output
output_path = "/content/schema_metadata_enriched.json"
with open(output_path, "w") as f:
    json.dump(enriched_schema, f, indent=2)

print(f"Enriched schema saved to {output_path}")


Enriched schema saved to /content/schema_metadata_enriched.json


In [None]:
import json
import re

# Load previously enriched schema
# (re-read from the JSON file so this cell works from the saved artifact)
with open("/content/schema_metadata_enriched.json", "r") as f:
    enriched_schema = json.load(f)

# Sample term dictionary
# Maps a column name to its business-friendly term. A key is either a bare
# column name ("AlbumId") or "Column (Table)" where the same column name
# appears in more than one table ("Title (Album)" vs "Title (Employee)").
business_term_dict =  {
    # Album/Artist
    "AlbumId": "Album Identifier",
    "Title (Album)": "Album Title",
    "ArtistId": "Artist Identifier",
    "Name (Artist)": "Artist Name",

    # Customer
    "CustomerId": "Customer Identifier",
    "FirstName (Customer)": "First Name",
    "LastName (Customer)": "Last Name",
    "Company (Customer)": "Customer Company Name",
    "Address (Customer)": "Customer Address",
    "City (Customer)": "Customer City",
    "State (Customer)": "Customer State",
    "Country (Customer)": "Customer Country",
    "PostalCode (Customer)": "Customer Postal Code",
    "Phone (Customer)": "Customer Phone Number",
    "Fax (Customer)": "Customer Fax Number",
    "Email (Customer)": "Customer Email Address",
    "SupportRepId": "Customer Support Representative Identifier",

    # Employee
    "EmployeeId": "Employee Identifier",
    "LastName (Employee)": "Last Name",
    "FirstName (Employee)": "First Name",
    "Title (Employee)": "Employee Job Title",
    "ReportsTo": "Employee Manager Identifier",
    "BirthDate": "Employee Birth Date",
    "HireDate": "Employee Hire Date",
    "Address (Employee)": "Employee Address",
    "City (Employee)": "Employee City",
    "State (Employee)": "Employee State",
    "Country (Employee)": "Employee Country",
    "PostalCode (Employee)": "Employee Postal Code",
    "Phone (Employee)": "Employee Phone Number",
    "Fax (Employee)": "Employee Fax Number",
    "Email (Employee)": "Employee Email Address",

    # Genre/Media Type
    "GenreId": "Genre Identifier",
    "Name (Genre)": "Genre Name",
    "MediaTypeId": "Media Type Identifier",
    "Name (MediaType)": "Media Type Name",

    # Invoice
    "InvoiceId": "Invoice Identifier",
    "CustomerId (Invoice)": "Customer Identifier",
    "InvoiceDate": "Invoice Date",
    "BillingAddress": "Invoice Billing Address",
    "BillingCity": "Invoice Billing City",
    "BillingState": "Invoice Billing State",
    "BillingCountry": "Invoice Billing Country",
    "BillingPostalCode": "Invoice Billing Postal Code",
    "Total": "Invoice Total Amount",

    # Invoice Line
    "InvoiceLineId": "Invoice Line Identifier",
    "InvoiceId (InvoiceLine)": "Invoice Identifier",
    "TrackId (InvoiceLine)": "Track Identifier",
    "UnitPrice (InvoiceLine)": "Invoice Line Unit Price",
    "Quantity": "Invoice Line Quantity",

    # Playlist/Track
    "PlaylistId": "Playlist Identifier",
    "Name (Playlist)": "Playlist Name",
    "TrackId (PlaylistTrack)": "Track Identifier",
    "Name (Track)": "Track Name",
    "Composer": "Track Composer",
    "Milliseconds": "Track Duration (Milliseconds)",
    "Bytes": "Track File Size (Bytes)",
    "UnitPrice (Track)": "Track Unit Price",
}

# Print the dictionary
# (one "key: term" line per entry, in insertion order, for visual inspection)
for term, definition in business_term_dict.items():
    print(f"{term}: {definition}")

# Normalize keys for matching (lookups below are case-insensitive)
normalized_dict = {k.lower(): v for k, v in business_term_dict.items()}

# Fallback descriptions for columns without a dictionary entry, as
# (table-name substring, template) pairs. They are checked in order and the
# first hit wins, so e.g. "PlaylistTrack" is described as a track ("track" is
# listed before "playlist") and "InvoiceLine" as an invoice.
fallback_descriptions = [
    ("customer", "{} of the customer."),
    ("invoice", "{} related to the invoice."),
    ("track", "{} of the track or song."),
    ("employee", "{} of the employee."),
    ("artist", "{} related to the artist or band."),
    ("album", "{} related to the album."),
    ("playlist", "{} related to the playlist."),
    ("mediatype", "{} describing the media format."),
    ("genre", "{} representing the music genre."),
]
default_description = "{} of the record."

# Enrich business_name, tagging, and description
for table in enriched_schema:
    table_name = table["table_name"].lower()

    for col in table["columns"]:
        col_key = col["column_name"].lower().replace("_", "")

        # Dictionary keys come in two shapes: "Column (Table)" for names shared
        # by several tables, and bare "Column". Try the table-qualified form
        # first so entries such as "Title (Album)" are actually found, then
        # fall back to the bare column name.
        business_term = (
            normalized_dict.get(f"{col_key} ({table_name})")
            or normalized_dict.get(col_key)
        )

        if business_term:
            col["business_name"] = business_term
            col["business_term_tagged"] = True
            col["description"] = f"{business_term} of the {table['table_name']}."
        else:
            col["business_name"] = col["column_name"]
            col["business_term_tagged"] = False

            # Context-sensitive description: first matching table keyword,
            # otherwise the generic wording.
            template = next(
                (text for keyword, text in fallback_descriptions if keyword in table_name),
                default_description,
            )
            col["description"] = template.format(col["column_name"])

# Save updated schema
output_path = "/content/schema_metadata_enriched_with_tags.json"
with open(output_path, "w") as f:
    json.dump(enriched_schema, f, indent=2)

print(f" Business term tagging completed and saved to {output_path}")


AlbumId: Album Identifier
Title (Album): Album Title
ArtistId: Artist Identifier
Name (Artist): Artist Name
CustomerId: Customer Identifier
FirstName (Customer): First Name
LastName (Customer): Last Name
Company (Customer): Customer Company Name
Address (Customer): Customer Address
City (Customer): Customer City
State (Customer): Customer State
Country (Customer): Customer Country
PostalCode (Customer): Customer Postal Code
Phone (Customer): Customer Phone Number
Fax (Customer): Customer Fax Number
Email (Customer): Customer Email Address
SupportRepId: Customer Support Representative Identifier
EmployeeId: Employee Identifier
LastName (Employee): Last Name
FirstName (Employee): First Name
Title (Employee): Employee Job Title
ReportsTo: Employee Manager Identifier
BirthDate: Employee Birth Date
HireDate: Employee Hire Date
Address (Employee): Employee Address
City (Employee): Employee City
State (Employee): Employee State
Country (Employee): Employee Country
PostalCode (Employee): Emplo

In [None]:
import pprint

# Dump every table entry: a header line, the pretty-printed dict, then a rule.
schema_printer = pprint.PrettyPrinter(indent=2)
separator = "=" * 80

for table in enriched_schema:
    header = f"\n Table: {table['table_name']} ({table['business_name']})"
    print(header)
    schema_printer.pprint(table)
    print(separator)


 Table: Album (Album)
{ 'business_name': 'Album',
  'columns': [ { 'business_name': 'Album Identifier',
                 'business_term_tagged': True,
                 'column_name': 'AlbumId',
                 'data_type': 'INTEGER',
                 'description': 'Album Identifier of the Album.',
                 'is_primary_key': True},
               { 'business_name': 'Title',
                 'business_term_tagged': False,
                 'column_name': 'Title',
                 'data_type': 'NVARCHAR(160)',
                 'description': 'Title related to the album.',
                 'is_primary_key': False},
               { 'business_name': 'Artist Identifier',
                 'business_term_tagged': True,
                 'column_name': 'ArtistId',
                 'data_type': 'INTEGER',
                 'description': 'Artist Identifier of the Album.',
                 'is_primary_key': False}],
  'description': 'Contains information related to Album.',
  'foreign_key

# Module 3

In [None]:


from sentence_transformers import SentenceTransformer
import faiss
import json
import os


# Load enriched schema with tags
with open("schema_metadata_enriched_with_tags.json", "r") as f:
    schema = json.load(f)

# Initialize embedding model
model = SentenceTransformer("all-MiniLM-L6-v2")

# Prepare text chunks and metadata: one chunk (and one metadata record) per table
chunks = []
metadata = []

for table in schema:
    chunk_text = f"Table: {table['business_name']} ({table['table_name']})\n"
    chunk_text += f"Description: {table['description']}\n"
    chunk_text += "Columns:\n"
    for col in table["columns"]:
        chunk_text += f" - {col['business_name']} ({col['column_name']}): {col['description']}\n"
    if table["foreign_keys"]:
        chunk_text += "Relationships:\n"
        for fk in table["foreign_keys"]:
            chunk_text += f" - {fk['from_column']} → {fk['to_table']}.{fk['to_column']}\n"
    chunks.append(chunk_text)
    metadata.append({
        "table_name": table["table_name"],
        "business_name": table["business_name"],
        "num_columns": len(table["columns"]),
        "num_foreign_keys": len(table["foreign_keys"])
    })

# Generate embeddings
# NOTE(review): nothing visible in this notebook reads `embeddings`; the
# vector store below embeds the documents itself. Kept in case a later cell
# relies on it — confirm and drop if unused.
embeddings = model.encode(chunks)

# Create FAISS index
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document

# Example inputs: chunks and model
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Attach each table's metadata to its document so a retrieved hit can be
# traced back to the table it describes (a copy, so the list saved below and
# the stored documents do not share dict objects).
docs = [
    Document(page_content=chunk, metadata=dict(meta))
    for chunk, meta in zip(chunks, metadata)
]

# Create vectorstore
vectorstore = FAISS.from_documents(docs, embedding_model)

# Save to a folder (not a file)
vectorstore.save_local("schema_faiss")

with open("schema_chunks.json", "w") as f:
    json.dump(chunks, f, indent=2)
with open("schema_metadata.json", "w") as f:
    json.dump(metadata, f, indent=2)

print("FAISS index and chunk metadata saved.")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

  embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


FAISS index and chunk metadata saved.


In [None]:


import json
import faiss
from sentence_transformers import SentenceTransformer
from google.colab import files

# Load the enriched schema
with open("schema_metadata_enriched_with_tags.json", "r") as f:
    schema = json.load(f)

# Load embedding model
model = SentenceTransformer("all-MiniLM-L6-v2")

# Prepare column-level chunks: one chunk (and one metadata record) per column
column_chunks = []
column_metadata = []

for table in schema:
    for col in table["columns"]:
        chunk_text = f"Table: {table['business_name']} ({table['table_name']})\n"
        chunk_text += f"Column: {col['business_name']} ({col['column_name']})\n"
        chunk_text += f"Type: {col['data_type']}\n"
        chunk_text += f"Description: {col['description']}\n"
        chunk_text += f"Primary Key: {'Yes' if col['is_primary_key'] else 'No'}\n"
        chunk_text += f"Business Term Tagged: {'Yes' if col.get('business_term_tagged') else 'No'}\n"
        # Inject foreign key info if applicable (only keys that start at this column)
        for fk in table["foreign_keys"]:
            if fk["from_column"] == col["column_name"]:
                chunk_text += f"Foreign Key: {fk['from_column']} → {fk['to_table']}.{fk['to_column']}\n"

        column_chunks.append(chunk_text)
        column_metadata.append({
            "table_name": table["table_name"],
            "column_name": col["column_name"],
            "business_name": col["business_name"],
            "is_primary_key": col["is_primary_key"]
        })

# Embed column-level chunks
# NOTE(review): nothing visible in this notebook reads `column_embeddings`;
# the vector store below embeds the documents itself. Kept in case a later
# cell relies on it — confirm and drop if unused.
column_embeddings = model.encode(column_chunks)

# Create FAISS index for column-level embeddings
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document

embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Attach each column's metadata to its document so a retrieved hit can be
# traced back to its table and column (a copy, so the list saved below and
# the stored documents do not share dict objects).
column_docs = [
    Document(page_content=chunk, metadata=dict(meta))
    for chunk, meta in zip(column_chunks, column_metadata)
]

# Build and save
column_vectorstore = FAISS.from_documents(column_docs, embedding_model)
column_vectorstore.save_local("schema_column_faiss")  # <- save in a folder with this exact name


with open("column_chunks.json", "w") as f:
    json.dump(column_chunks, f, indent=2)
with open("column_metadata.json", "w") as f:
    json.dump(column_metadata, f, indent=2)

print("Column-level FAISS index and metadata saved.")


Column-level FAISS index and metadata saved.


In [None]:
# Show the first five column-level chunks, numbered from 1.
for position, preview in enumerate(column_chunks[:5], start=1):
    print(f"\n--- Column Chunk {position} ---\n{preview}")


--- Column Chunk 1 ---
Table: Album (Album)
Column: Album Identifier (AlbumId)
Type: INTEGER
Description: Album Identifier of the Album.
Primary Key: Yes
Business Term Tagged: Yes


--- Column Chunk 2 ---
Table: Album (Album)
Column: Title (Title)
Type: NVARCHAR(160)
Description: Title related to the album.
Primary Key: No
Business Term Tagged: No


--- Column Chunk 3 ---
Table: Album (Album)
Column: Artist Identifier (ArtistId)
Type: INTEGER
Description: Artist Identifier of the Album.
Primary Key: No
Business Term Tagged: Yes
Foreign Key: ArtistId → Artist.ArtistId


--- Column Chunk 4 ---
Table: Artist (Artist)
Column: Artist Identifier (ArtistId)
Type: INTEGER
Description: Artist Identifier of the Artist.
Primary Key: Yes
Business Term Tagged: Yes


--- Column Chunk 5 ---
Table: Artist (Artist)
Column: Name (Name)
Type: NVARCHAR(120)
Description: Name related to the artist or band.
Primary Key: No
Business Term Tagged: No



# Module 4

In [None]:
import json
import faiss
from sentence_transformers import SentenceTransformer

# Placeholder chunks representing business knowledge (simulated from genbi.pdf)
# Each entry is one free-text snippet, prefixed with its category
# ("Business Rule:", "KPI:", "Term Mapping:", ...). The whole list is embedded
# and written to genbi_chunks.json further down in this cell.
genbi_chunks = [
    # Business Rules
    # NOTE(review): the first rule joins on Invoice.SupportRepId, while the
    # "Relationship" entry below places SupportRepId on Customer — confirm
    # which one matches the database.
    # NOTE(review): Employee.CommissionRate and Track.UnitCost are not listed in
    # the term dictionary earlier in this notebook — TODO confirm these columns exist.
    "Business Rule: Employee Commission = Invoice.Total * Employee.CommissionRate where Invoice.SupportRepId = Employee.EmployeeId",
    "Business Rule: Track Profitability = (InvoiceLine.UnitPrice - Track.UnitCost) * InvoiceLine.Quantity",
    "Business Rule: Album Revenue = SUM(InvoiceLine.UnitPrice * InvoiceLine.Quantity) for all Tracks in an Album",
    "Business Rule: Artist Revenue = SUM(Album Revenue) for all Albums by an Artist",

    # KPIs
    "KPI: Customer Lifetime Value (CLV) = SUM(Invoice.Total) grouped by Customer.CustomerId",
    "KPI: Purchase Frequency = COUNT(Invoice.InvoiceId) / COUNT(DISTINCT Customer.CustomerId)",
    "KPI: Average Revenue Per Track = SUM(InvoiceLine.UnitPrice * InvoiceLine.Quantity) / COUNT(DISTINCT Track.TrackId)",
    "KPI: Employee Sales Performance = SUM(Invoice.Total) grouped by Employee.EmployeeId",
    "KPI: Genre Popularity = COUNT(InvoiceLine.InvoiceLineId) grouped by Genre.Name",

    # Term Mappings
    "Term Mapping: 'Sales' maps to 'Invoice.Total'",
    "Term Mapping: 'Purchase' maps to 'Invoice with associated InvoiceLines'",
    "Term Mapping: 'Song' maps to 'Track'",
    "Term Mapping: 'Customer Spend' maps to 'SUM(Invoice.Total) for a specific Customer'",
    "Term Mapping: 'Sales Rep' maps to 'Employee who is a SupportRep for Customer'",

    # Relationships & Join Paths
    "Relationship: Each Invoice belongs to exactly one Customer (Invoice.CustomerId -> Customer.CustomerId)",
    "Relationship: Each Customer is supported by one Employee (Customer.SupportRepId -> Employee.EmployeeId)",
    "Join Path: Track popularity analysis: Track → InvoiceLine → COUNT(InvoiceLine.InvoiceLineId)",
    "Join Path: Customer purchase by genre: Customer → Invoice → InvoiceLine → Track → Genre",
    "Join Path: Employee performance: Employee → Customer → Invoice → SUM(Invoice.Total)",

    # Business Context
    "Context: Customers can purchase individual tracks rather than complete albums",
    "Context: MediaType indicates format (MPEG, AAC, etc.) which may affect pricing",
    "Context: Some tracks appear on multiple albums (compilations, greatest hits, etc.)",
    "Context: SupportRepId indicates which employee provides customer support/sales to a customer",

    # Common Metrics Calculations
    "Metric: Top Selling Tracks = COUNT(InvoiceLine.InvoiceLineId) grouped by Track.TrackId order by count DESC",
    "Metric: Customer Segmentation by Genre = For each Customer, find the Genre with MAX(purchase count)",
    "Metric: Employee Territory Performance = SUM(Invoice.Total) grouped by Employee.EmployeeId, Customer.Country",
    "Metric: Album Completion Rate = For each customer and album, (Distinct tracks purchased from album) / (Total tracks in album)",

    # Advanced Analytics
    "Analytics: Customer Genre Affinity = % of purchases in each genre per customer compared to overall distribution",
    "Analytics: Bundle Recommendations = Tracks frequently purchased together but not in same album",
    "Analytics: Customer Churn Risk = Customers with declining purchase frequency or increasing time between purchases",
    "Analytics: Price Sensitivity = Change in purchase behavior following price changes",

    # Query Cautions
    "Caution: When calculating average revenue per customer, filter out customers with no purchases",
    "Caution: Track purchase counts may need normalization by time period for trending analysis",
    "Caution: Employee performance should account for different territory sizes and customer counts",
    "Caution: Media types may have different pricing strategies affecting revenue comparisons",

    # Temporal Analysis
    "Temporal: Recent Customer = Customer with purchase in last 30 days from analysis date",
    "Temporal: Purchase Trend = LINEAR_REGRESSION(Invoice.Total) grouped by month over trailing 12 months",
    "Temporal: Seasonal Analysis = Compare quarterly performance accounting for yearly seasonality",
    "Temporal: Employee Growth Rate = Month-over-month change in employee's total sales"
]

# Embed with sentence-transformers
model = SentenceTransformer("all-MiniLM-L6-v2")
genbi_embeddings = model.encode(genbi_chunks)

# Create FAISS index
# NOTE(review): this raw index is not saved or queried by anything visible in
# this notebook; the LangChain vector store below is what gets persisted.
dim = genbi_embeddings[0].shape[0]
genbi_index = faiss.IndexFlatL2(dim)
genbi_index.add(genbi_embeddings)

# Save outputs
from langchain.vectorstores import FAISS
from langchain.docstore import InMemoryDocstore
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document

embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Index the business-knowledge snippets defined in this cell. The previous
# version iterated over `chunks` — the table-level schema chunks left in the
# kernel by an earlier cell — so the "business" store contained schema text.
docs = [Document(page_content=chunk) for chunk in genbi_chunks]

vectorstore = FAISS.from_documents(docs, embedding_model)
# save_local() treats the argument as a FOLDER name and writes index.faiss and
# index.pkl inside it; whoever loads the store must pass this same folder name.
vectorstore.save_local("genbi_faiss.index")

with open("genbi_chunks.json", "w") as f:
    json.dump(genbi_chunks, f, indent=2)

print("Placeholder business knowledge base saved as genbi_faiss.index and genbi_chunks.json")


Placeholder business knowledge base saved as genbi_faiss.index and genbi_chunks.json


# Module 5

In [None]:
# LangChain-based RAG Core Setup for Module 5: Retrieval Engine

from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from langchain.retrievers import EnsembleRetriever
import faiss
import json

# Initialize embedding model (must be the same model the stores were built with)
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Load vectorstores
# allow_dangerous_deserialization=True un-pickles the stored docstore; that is
# acceptable here only because these folders are produced by this notebook.
# Do not point these loaders at files from an untrusted source.
table_vectorstore = FAISS.load_local(
    folder_path="schema_faiss",
    embeddings=embedding_model,
    allow_dangerous_deserialization=True
)

column_vectorstore = FAISS.load_local(
    folder_path="schema_column_faiss",
    embeddings=embedding_model,
    allow_dangerous_deserialization=True
)

# The business-knowledge store is optional, so a failure here is tolerated —
# but reported, not hidden. The folder name matches what the business-knowledge
# cell passes to save_local ("genbi_faiss.index"), and the deserialization flag
# is required just like for the two stores above; without either, the load
# always failed and the business context was silently empty.
try:
    genbi_vectorstore = FAISS.load_local(
        folder_path="genbi_faiss.index",
        embeddings=embedding_model,
        allow_dangerous_deserialization=True
    )
except Exception as exc:
    print(f"Warning: business knowledge store could not be loaded ({exc}); continuing without it.")
    genbi_vectorstore = None

# Define your query
user_question = biz_question

# Individually retrieve relevant docs
# (.invoke() replaces the deprecated get_relevant_documents())
table_docs = table_vectorstore.as_retriever(search_kwargs={"k": 5}).invoke(user_question)
column_docs = column_vectorstore.as_retriever(search_kwargs={"k": 5}).invoke(user_question)
genbi_docs = genbi_vectorstore.as_retriever(search_kwargs={"k": 3}).invoke(user_question) if genbi_vectorstore else []

# Combine and deduplicate schema chunks
# dict.fromkeys() drops duplicates while keeping first-seen order, so the
# retriever's ranking is preserved and the output is the same on every run
# (a set would shuffle the chunks).
schema_chunks = list(dict.fromkeys(doc.page_content for doc in (table_docs + column_docs)))

# Separate business chunks
business_chunks = list(dict.fromkeys(doc.page_content for doc in genbi_docs))

# Optional: print for inspection
print("\n Schema Chunks:")
for i, chunk in enumerate(schema_chunks):
    print(f"\n--- Chunk {i+1} ---\n{chunk}")

print("\nBusiness Chunks:")
for i, chunk in enumerate(business_chunks):
    print(f"\n--- Chunk {i+1} ---\n{chunk}")





 Schema Chunks:

--- Chunk 1 ---
Table: Customer (Customer)
Column: City (City)
Type: NVARCHAR(40)
Description: City of the customer.
Primary Key: No
Business Term Tagged: No


--- Chunk 2 ---
Table: Employee (Employee)
Column: Country (Country)
Type: NVARCHAR(40)
Description: Country of the employee.
Primary Key: No
Business Term Tagged: No


--- Chunk 3 ---
Table: Media Type (MediaType)
Description: Contains information related to Media Type.
Columns:
 - Media Type Identifier (MediaTypeId): Media Type Identifier of the MediaType.
 - Name (Name): Name describing the media format.


--- Chunk 4 ---
Table: Invoice Line (InvoiceLine)
Description: Contains information related to Invoice Line.
Columns:
 - Invoice Line Identifier (InvoiceLineId): Invoice Line Identifier of the InvoiceLine.
 - Invoice Identifier (InvoiceId): Invoice Identifier of the InvoiceLine.
 - TrackId (TrackId): TrackId related to the invoice.
 - UnitPrice (UnitPrice): UnitPrice related to the invoice.
 - Invoice Line

  table_docs = table_vectorstore.as_retriever(search_kwargs={"k": 5}).get_relevant_documents(user_question)


In [None]:
# Print the retrieved schema context, one chunk per print call.
for schema_text in schema_chunks:
    print(schema_text)

Table: Customer (Customer)
Column: City (City)
Type: NVARCHAR(40)
Description: City of the customer.
Primary Key: No
Business Term Tagged: No

Table: Employee (Employee)
Column: Country (Country)
Type: NVARCHAR(40)
Description: Country of the employee.
Primary Key: No
Business Term Tagged: No

Table: Media Type (MediaType)
Description: Contains information related to Media Type.
Columns:
 - Media Type Identifier (MediaTypeId): Media Type Identifier of the MediaType.
 - Name (Name): Name describing the media format.

Table: Invoice Line (InvoiceLine)
Description: Contains information related to Invoice Line.
Columns:
 - Invoice Line Identifier (InvoiceLineId): Invoice Line Identifier of the InvoiceLine.
 - Invoice Identifier (InvoiceId): Invoice Identifier of the InvoiceLine.
 - TrackId (TrackId): TrackId related to the invoice.
 - UnitPrice (UnitPrice): UnitPrice related to the invoice.
 - Invoice Line Quantity (Quantity): Invoice Line Quantity of the InvoiceLine.
Relationships:
 - Tra

# Module 6


In [None]:
# Run this to set your API key securely
import os
from getpass import getpass

# Never paste the key into the notebook source: it would be saved with the
# file and leak when the notebook is shared. Prompt for it instead (input is
# not echoed), and keep any key that is already present in the environment
# rather than overwriting it.
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass("Enter your Groq API key: ")


In [None]:
def get_system_prompt():
    """Return the fixed system message that frames the model as an SQL generator."""
    system_prompt = (
        "You are a helpful assistant that generates SQL queries "
        "for a database using schema and business logic."
    )
    return system_prompt

def get_sql_generation_prompt(user_question, schema_context, business_rules_context, examples=None):
    """Build the user prompt that asks the LLM for one SQLite query plus explanation and chart hint.

    Args:
        user_question: Natural-language question to answer.
        schema_context: Text describing the tables/columns the model may use.
        business_rules_context: Text with business rules to apply.
        examples: Optional iterable of ``(question, sql)`` pairs. When non-empty they are
            rendered as a few-shot section placed between the business rules and the
            guidelines. (Previously they were formatted but never inserted.)

    Returns:
        The prompt string with surrounding whitespace stripped. With no examples the
        text is identical to the prompt produced before examples were wired in.
    """
    example_section = ""
    if examples:
        example_block = "".join(
            f"Example Question: {q}\nExample SQL:\n{sql.strip()}\n\n"
            for q, sql in examples
        )
        # The header reminds the model that examples show style only, so an example
        # written against another schema cannot override the schema context.
        example_section = (
            "Examples (they illustrate the expected SQL style only; always take table "
            "and column names from the schema above):\n"
            f"{example_block}\n"
        )

    return f"""
You are a helpful assistant that generates SQL queries based on a user's question, the database schema, and relevant business rules.

Respond in this format:

Question:
{user_question}

Schema:
{schema_context}


Business Rules:
{business_rules_context}


{example_section}Guidelines:
- Use only the tables and columns provided in the schema context.
- - Use the exact table and column names as shown in the schema context.
- Join tables only when there is a foreign key relationship.
- Apply business rules exactly as described when calculating KPIs or metrics.
- Use clear, aliased column names suitable for visualization.
- Avoid guessing any data model structures not present in the context.
- First, return the SQL code.
- Then, explain what the query does in plain English.
- If foreign key relationships are given in the schema, use them when writing JOIN clauses. Do not guess or write JOIN 1=1.
- Use only the exact column names shown in parentheses in the schema (e.g., "BillingCity"), not the business names.
- This is a SQLite database.
- Do not use SQL Server-specific functions like GETDATE(), DATEDIFF(), or YEAR().
- Use SQLite-compatible date functions such as julianday(), date(), and CURRENT_DATE.
- Use julianday() and CURRENT_DATE for date operations if needed.
- Always use InvoiceLine for revenue-related calculations
- For any revenue-related question, ALWAYS use InvoiceLine.UnitPrice * InvoiceLine.Quantity.
- Never use Track.UnitPrice unless the user specifically asks for "list price" or "track price."
- To calculate total revenue by artist, you must join: Artist → Album → Track → InvoiceLine.
- Do not invent column names. Use only those shown in the schema
- Whenever revenue needs to be calculated, ALWAYS use InvoiceLine table
-If the user asks for a percentage of a total (e.g. revenue by group), calculate the group sum, divide it by the total sum, and multiply by 100.
- Use a subquery to compute total revenue if needed.
- If a column error is detected, always refer to the exact column names under 'Column:' from the schema chunk. For example, if the schema lists 'Name' and not 'Title', use 'Name'.
- CustomerId is not present in InvoiceLine.
- To get CustomerId for a track purchase, go through: Customer → Invoice → InvoiceLine.
- If the user asks for a distinct count, use COUNT(DISTINCT ...), not COUNT(*).
- If the user asks for the number of unique items (e.g., customers, products, genres), use COUNT(DISTINCT ...).
- Only use COUNT(*) when the table is guaranteed to contain one row per entity (e.g., CustomerId is the primary key).
- Do not assume uniqueness — prefer COUNT(DISTINCT column) unless the user asks for all rows.
- Use aliases like AS TotalCustomers for clarity when returning counts.
- Always respond with a single, final SQL query that answers the question.
- Do not include multiple speculative queries or fallback attempts.
- Avoid repeating multiple alternative SQL blocks — include only the correct, final version.
- Do not assume table names, if its present please use it.
- Always respond with single query that answers the question.
- Do not include multiple speculative queries or fallback attempts.
- When returning a result with a calculated metric (e.g., revenue, quantity, growth rate), clearly include the appropriate unit in the explanation.
- Use context-aware units like dollars, tracks sold, customers, or percentages.
- If using aliases in the SQL (e.g., TotalRevenue), explain what the number represents and its unit.
- Please give only one Sqllite query not multiple
- Note: The 'Customer' table does NOT have a 'Name' column. To refer to the customer's full name, use:
    Customer.FirstName || ' ' || Customer.LastName AS CustomerName
- Only reference years that are actually present in the dataset.
- Use descriptive column aliases such as AlbumCount, CustomerTotal, etc.
- Avoid using 'COUNT' or 'SUM' as alias names.

You can dynamically retrieve available years using:
SELECT DISTINCT STRFTIME('%Y', InvoiceDate) FROM Invoice ORDER BY 1 DESC;

Use only those years for comparisons, filters, and growth rate calculations.






Respond in this format:
Question:
{user_question}

SQL:
<Write the SQL query that answers the question>

Explanation:
<Briefly explain what the SQL query does in business terms>

Chart:
ChartType: <bar | line | pie | scatter>
X: <column_name used for x-axis>
Y: <column_name used for y-axis>

Example:
ChartType: bar
X: Genre
Y: TotalRevenue



""".strip()



In [None]:
#Module 8 - Self-healing prompt builder

def get_self_healing_prompt(user_question, original_sql, error_message, schema_context=""):
    """Build the retry prompt asking the LLM to repair a SQL query that failed.

    Args:
        user_question: The question the failing query was meant to answer.
        original_sql: The SQL text that produced the error.
        error_message: The error text reported for ``original_sql``.
        schema_context: Schema description inserted under ``Schema:``; defaults to
            an empty string.

    Returns:
        The prompt string with surrounding whitespace stripped. It embeds the four
        arguments, the repair rules, and the ``SQL:`` / ``Explanation:`` / ``Chart:``
        response format the model is asked to follow.
    """
    return f"""
You are an expert SQL assistant. A SQL query was generated for the question below but it caused an error during execution.

Question:
{user_question}

Original SQL:
{original_sql}

SQL Error Message:
{error_message}

Schema:
{schema_context}

Your task:
- Analyze the original SQL and the error.
- Provide a brief explanation of the fix.
- Suggest a chart to visualize the result if applicable.
- This is a SQLite database.
- Regenerate a corrected SQL query using **only column names shown in the schema**. Do not assume a column like 'Name' exists unless it is listed.
- When referencing customers, use 'FirstName' and 'LastName' instead of 'Name' if that matches the schema.
- Do not use SQL Server-specific functions like GETDATE(), DATEDIFF(), or YEAR().
- Use SQLite-compatible date functions such as julianday(), date(), and CURRENT_DATE.
- This is a SQLite database.
- Use julianday() and CURRENT_DATE for date operations if needed.
- If the error message mentions a column that doesn't exist, double-check the schema chunk above for the correct name and replace it. Do not repeat invalid column names in the fixed query.
- If a column error is detected, always refer to the exact column names under 'Column:' from the schema chunk. For example, if the schema lists 'Name' and not 'Title', use 'Name'.
- Always use InvoiceLine for revenue-related calculations
- For any revenue-related question, ALWAYS use InvoiceLine.UnitPrice * InvoiceLine.Quantity.
- Never use Track.UnitPrice unless the user specifically asks for "list price" or "track price."
- To calculate total revenue by artist, you must join: Artist → Album → Track → InvoiceLine.
- Do not invent column names. Use only those shown in the schema
- Whenever revenue needs to be calculated, ALWAYS use InvoiceLine table
- If the user asks for a percentage of a total (e.g. revenue by group), calculate the group sum, divide it by the total sum, and multiply by 100.
- Use a subquery to compute total revenue if needed.
- Use a CTE( common table expression) to compute total revenue if needed.
- Use a CTE( common table expression) to break down the query and arrive at the final result.
- CustomerId is not present in InvoiceLine.
- To get CustomerId for a track purchase, go through: Customer → Invoice → InvoiceLine.
- Use only the exact column names provided in the schema (e.g., ArtistId, AlbumId).
- Do not invent field names like ArtistIdentifier or AlbumIdentifier.
- Always validate the join path: Album → Track → InvoiceLine → Invoice → Customer
- If the user asks for a distinct count, use COUNT(DISTINCT ...), not COUNT(*).
- Only use COUNT(*) when the table has one row per entity (e.g., CustomerId is the PK).
- If the user asks for a distinct count, use COUNT(DISTINCT ...), not COUNT(*).
- If the user asks for the number of unique items (e.g., customers, products, genres), use COUNT(DISTINCT ...).
- Only use COUNT(*) when the table is guaranteed to contain one row per entity (e.g., CustomerId is the primary key).
- Do not assume uniqueness — prefer COUNT(DISTINCT column) unless the user asks for all rows.
- Use aliases like AS TotalCustomers for clarity when returning counts.
- When returning a result with a calculated metric (e.g., revenue, quantity, growth rate), clearly include the appropriate unit in the explanation.
- Use context-aware units like dollars, tracks sold, customers, or percentages.
- If using aliases in the SQL (e.g., TotalRevenue), explain what the number represents and its unit.
- Apply SQLite-safe practices when rewriting window functions.
- Analyze the original SQL and the error.
- Regenerate a corrected SQL query using only the tables and columns shown in the schema.
- **If referring to customer names, use 'FirstName' and 'LastName'. Never use 'Name' unless it's shown in the schema.**
Note: The 'Customer' table does NOT have a 'Name' column. To refer to the customer's full name, use:
    Customer.FirstName || ' ' || Customer.LastName AS CustomerName
- Avoid naming SQL columns as 'COUNT' or 'SUM'. Use descriptive names like 'AlbumCount' or 'TotalPurchases'.
- Instead, use descriptive names like 'AlbumCount', 'TotalRevenue', etc.
-Only reference years that are actually present in the dataset.

-You can dynamically retrieve available years using:
SELECT DISTINCT STRFTIME('%Y', InvoiceDate) FROM Invoice ORDER BY 1 DESC;

-Use only those years for comparisons, filters, and growth rate calculations.


Important rules for fixing LAG/LEAD functions:
- DO NOT repeat LAG() or LEAD() multiple times in an arithmetic expression.
- First, materialize them in a CTE or subquery. Then calculate the growth or difference using aliases.
- Example:
    -- BAD:
    SELECT (revenue - LAG(revenue)) / LAG(revenue) FROM table;

    -- GOOD:
    WITH base AS (
        SELECT year, revenue, LAG(revenue) OVER (ORDER BY year) AS prev_revenue
        FROM ...
    )
    SELECT year, revenue, (revenue - prev_revenue) / prev_revenue FROM base;



Respond in this format:

SQL:
<Corrected SQL>

Explanation:
<Brief fix explanation>

Chart:
ChartType: <bar | line | pie | scatter>
X: <column_name used for x-axis>
Y: <column_name used for y-axis>

Example:
ChartType: bar
X: Genre
Y: TotalRevenue

""".strip()


In [None]:
import os
from groq import Groq

# Setup Groq client (key is read from the GROQ_API_KEY environment variable)
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# Retrieved context. dict.fromkeys() removes duplicate chunks while keeping retrieval
# order; a set would reorder the chunks differently on every kernel restart and make
# the prompt (and therefore the LLM output) non-reproducible.
user_question = biz_question
schema_chunks = list(dict.fromkeys(doc.page_content for doc in (table_docs + column_docs)))
business_chunks = list(dict.fromkeys(doc.page_content for doc in genbi_docs)) if genbi_vectorstore else []

# Few-shot examples written against the Chinook tables this notebook queries
# (the earlier AdventureWorks examples referenced tables that do not exist here).
example_qa_pairs = [
    (
        "How many tracks are there per genre?",
        "SELECT g.Name AS Genre, COUNT(t.TrackId) AS TrackCount "
        "FROM Genre g JOIN Track t ON t.GenreId = g.GenreId "
        "GROUP BY g.Name ORDER BY TrackCount DESC;",
    ),
    (
        "What is the total revenue per customer?",
        "SELECT c.FirstName || ' ' || c.LastName AS CustomerName, "
        "SUM(il.UnitPrice * il.Quantity) AS TotalRevenue "
        "FROM Customer c "
        "JOIN Invoice i ON i.CustomerId = c.CustomerId "
        "JOIN InvoiceLine il ON il.InvoiceId = i.InvoiceId "
        "GROUP BY c.CustomerId ORDER BY TotalRevenue DESC;",
    ),
]

# Build schema context: one retrieved chunk per line
schema_context = "\n".join(schema_chunks)

# Build the prompt
system_prompt = get_system_prompt()
user_prompt = get_sql_generation_prompt(
    user_question=user_question,
    schema_context=schema_context,
    business_rules_context="\n".join(business_chunks),
    examples=example_qa_pairs
)

# Groq model call
chat_completion = client.chat.completions.create(
    messages=[
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ],
    model="llama-3.3-70b-versatile"
)

# Output
print("Groq Response:")
print(chat_completion.choices[0].message.content)


Groq Response:
Question:
Which country has the most customers?

SQL:
```sql
SELECT 
    Country, 
    COUNT(DISTINCT CustomerId) AS TotalCustomers
FROM 
    Customer
GROUP BY 
    Country
ORDER BY 
    TotalCustomers DESC
LIMIT 1;
```

Explanation:
This SQL query identifies the country with the most customers by grouping the Customer table by the Country column and counting the distinct CustomerId in each group. The country with the highest count of distinct customers is then returned as the result. This query essentially determines the country where the greatest number of unique customers reside.

Chart:
ChartType: bar
X: Country
Y: TotalCustomers


# Module 7

In [None]:
import re
import sqlite3
import json

# --- STEP 1: Extract structured sections from the LLM response ---


def extract_cte_names(sql_text):
    """Return the set of CTE names declared in ``sql_text`` (original casing kept).

    Recognises every entry of a WITH list, not only the first one:
    ``WITH a AS (...), b AS (...)``, ``WITH RECURSIVE cnt(x) AS (...)`` and the
    optional ``[NOT] MATERIALIZED`` keyword before the opening parenthesis.

    Args:
        sql_text: SQL statement to scan.

    Returns:
        A set of name strings; empty when the statement declares no CTE.
    """
    pattern = re.compile(
        r"(?:\bwith\s+(?:recursive\s+)?|,\s*)"      # start of the WITH list or a later entry
        r"([a-zA-Z0-9_]+)"                           # the CTE name
        r"(?:\s*\([^()]*\))?"                        # optional column list: name(col1, col2)
        r"\s+as\s*(?:(?:not\s+)?materialized\s*)?\(",
        re.IGNORECASE,
    )
    return set(pattern.findall(sql_text))


def parse_llm_response(response_text):
    """Split an LLM answer into its SQL, explanation and chart parts.

    Args:
        response_text: Raw model output. ``None`` or an empty string yields a
            result whose values are all ``None``.

    Returns:
        dict with keys ``sql``, ``explanation``, ``chart`` (raw text after
        ``Chart:``) and ``chart_details`` (``chart_type``, ``x``, ``y``). ``y`` is the
        text on the ``Y:`` line only and may hold several comma-separated names.
        Any part that cannot be found is ``None``.
    """
    sections = {
        "sql": None,
        "explanation": None,
        "chart": None,
        "chart_details": {
            "chart_type": None,
            "x": None,
            "y": None
        }
    }
    response_text = response_text or ""

    # Preferred source: a fenced ```sql / ```sqlite block.
    sql_match = re.search(r"```(?:sqlite|sql)\b(.*?)```", response_text, re.DOTALL | re.IGNORECASE)
    if not sql_match:
        # Unfenced CTE query: must start a line with "WITH <name> AS (" so that the
        # capture does not begin at the SELECT inside the CTE body.
        sql_match = re.search(
            r"^[ \t]*(WITH\s+(?:RECURSIVE\s+)?\w+(?:\s*\([^()]*\))?\s+AS\s*\([\s\S]+?;)",
            response_text,
            re.IGNORECASE | re.MULTILINE,
        )
    if not sql_match:
        # Last resort: first "SELECT ... ;" span anywhere in the text.
        sql_match = re.search(r"(SELECT[\s\S]+?;)", response_text, re.IGNORECASE)

    sections["sql"] = sql_match.group(1).strip() if sql_match else None

    # Explanation runs until the "Chart:" heading or the end of the text.
    explanation_match = re.search(r"(?i)Explanation:\s*(.*?)(?:\nChart:|$)", response_text, re.DOTALL)
    sections["explanation"] = explanation_match.group(1).strip() if explanation_match else None

    # Chart suggestion: everything after "Chart:".
    chart_match = re.search(r"(?i)Chart:\s*((?:.|\n)*)", response_text)
    if chart_match:
        chart_block = chart_match.group(1).strip()
        sections["chart"] = chart_block

        # Values are read from their own line only ([ \t] instead of \s), so text that
        # follows the Y line is no longer swallowed into the column name.
        chart_type_match = re.search(r"ChartType:[ \t]*(\w+)", chart_block, re.IGNORECASE)
        x_match = re.search(r"\bX:[ \t]*([a-zA-Z_][a-zA-Z0-9_]*)", chart_block, re.IGNORECASE)
        y_match = re.search(r"\bY:[ \t]*([a-zA-Z0-9_, \t]+)", chart_block, re.IGNORECASE)

        sections["chart_details"]["chart_type"] = chart_type_match.group(1).strip() if chart_type_match else None
        sections["chart_details"]["x"] = x_match.group(1).strip() if x_match else None
        sections["chart_details"]["y"] = y_match.group(1).strip() if y_match else None

    return sections


# --- STEP 2: SQL syntax check using sqlite3 ---

def check_sql_syntax(sql_query, database=None):
    """Check that SQLite can compile ``sql_query`` without running it.

    The statement is prefixed with ``EXPLAIN QUERY PLAN``, so it is only planned.

    Args:
        sql_query: SQL text to check. ``None`` or blank text is reported as invalid.
        database: Optional path of the SQLite database to check against. When omitted
            the notebook-level ``db_path`` variable is used, as before.

    Returns:
        ``(True, None)`` when the statement compiles, otherwise ``(False, message)``.
    """
    if not sql_query or not sql_query.strip():
        return False, "No SQL query provided."

    conn = None  # stays None if connect() itself fails, so finally cannot hit an unbound name
    try:
        conn = sqlite3.connect(db_path if database is None else database)
        conn.execute("EXPLAIN QUERY PLAN " + sql_query)
        return True, None
    except (sqlite3.Error, sqlite3.Warning) as e:
        # sqlite3.Warning is raised by some Python versions for multi-statement input.
        return False, str(e)
    finally:
        if conn is not None:
            conn.close()


# --- STEP 3: Reference check against JSON schema ---

def validate_sql_references(sql_query, schema_json_path="/content/schema_metadata_enriched_with_tags.json"):
    """Check the tables and qualified columns used by ``sql_query`` against the schema JSON.

    The schema file is read as a list of objects with a ``table_name`` and a
    ``columns`` list whose items carry ``column_name``. Comparison is case-insensitive.

    Checks performed:
      * every name after FROM / JOIN must be a schema table or a CTE declared in
        the query (CTE names are compared case-insensitively);
      * every ``qualifier.column`` reference whose qualifier resolves to exactly one
        schema table (by table name or FROM/JOIN alias) must name a column of that
        table. Unqualified columns and qualifiers bound to CTEs are not checked.

    Args:
        sql_query: SQL text to inspect.
        schema_json_path: Path to the schema metadata JSON file.

    Returns:
        ``(is_valid, missing_tables, missing_columns)``. Both sets hold lower-case
        names; missing columns are reported as ``"table.column"``.
    """
    with open(schema_json_path, "r") as f:
        schema = json.load(f)

    valid_tables = {
        t["table_name"].lower(): {col["column_name"].lower() for col in t["columns"]}
        for t in schema
    }

    sql_lower = sql_query.lower()

    # Lower-cased once and kept that way: the table names below are lower-case too.
    defined_ctes = {cte.lower() for cte in extract_cte_names(sql_query)}
    print("Detected CTEs:", defined_ctes)

    # Words that may follow a table name but are keywords, not aliases.
    non_alias_words = {
        "as", "on", "where", "join", "inner", "left", "right", "full", "outer", "cross",
        "natural", "group", "order", "limit", "union", "except", "intersect", "having",
        "using", "window",
    }

    missing_tables = set()
    # qualifier -> schema table name, or None when the qualifier is a CTE / ambiguous.
    qualifier_to_table = {}

    def bind(qualifier, table):
        if qualifier in qualifier_to_table and qualifier_to_table[qualifier] != table:
            qualifier_to_table[qualifier] = None
        else:
            qualifier_to_table[qualifier] = table

    for source, alias in re.findall(r"\b(?:from|join)\s+(\w+)(?:\s+(?:as\s+)?(\w+))?", sql_lower):
        if source in defined_ctes:
            resolved = None
        elif source in valid_tables:
            resolved = source
        else:
            missing_tables.add(source)
            continue
        bind(source, resolved)
        if alias and alias not in non_alias_words:
            bind(alias, resolved)

    missing_columns = set()
    for qualifier, column in re.findall(r"\b(\w+)\.(\w+)", sql_lower):
        table = qualifier_to_table.get(qualifier)
        if table is not None and column not in valid_tables[table]:
            missing_columns.add(f"{table}.{column}")

    is_valid = not missing_tables and not missing_columns
    return is_valid, missing_tables, missing_columns



In [None]:
# 1. Extract LLM output text from Module 6 response
llm_response = chat_completion.choices[0].message.content

# 2. Feed it directly into your Module 7 validator
parsed = parse_llm_response(llm_response)
chart_type = parsed["chart_details"]["chart_type"]
x_col = parsed["chart_details"]["x"]
y_raw = parsed["chart_details"]["y"]
y_col = [col.strip() for col in y_raw.split(",")] if y_raw else None


# 3. Validate syntax
# Defaults keep the summary prints below working when no SQL block was found.
missing_tables, missing_columns = set(), set()
if parsed["sql"]:
    syntax_ok, syntax_error = check_sql_syntax(parsed["sql"])
    ref_ok, missing_tables, missing_columns = validate_sql_references(parsed["sql"])

    print("Syntax OK" if syntax_ok else f"Syntax Error: {syntax_error}")
    print("Schema Check Passed" if ref_ok else f"Missing: Tables={missing_tables}, Columns={missing_columns}")

else:
    print(" No SQL block found in LLM response.")

# Print explanation and chart regardless
print("Missing Tables (final):", missing_tables)
# Without any SQL there is nothing that could be valid.
print("Final is_valid:", bool(parsed["sql"]) and not missing_tables and not missing_columns)

print("\nExplanation:\n", parsed["explanation"] or "No explanation found.")




Detected CTEs: set()
Syntax OK
Schema Check Passed
Missing Tables (final): set()
Final is_valid: True

Explanation:
 This SQL query identifies the country with the most customers by grouping the Customer table by the Country column and counting the distinct CustomerId in each group. The country with the highest count of distinct customers is then returned as the result. This query essentially determines the country where the greatest number of unique customers reside.


In [None]:
import sqlite3
# List the tables present in the SQLite database at db_path.
conn = sqlite3.connect(db_path)
cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
print([table_name for (table_name,) in cursor.fetchall()])
conn.close()

['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


# Module 8

In [None]:
import sqlite3

def execute_sql_with_self_heal(original_sql, user_question, schema_chunks, groq_client, model="llama-3.3-70b-versatile"):
    """Check ``original_sql`` against SQLite and, if it fails, ask the LLM for a repaired query.

    The check plans the statement with ``EXPLAIN QUERY PLAN`` on the notebook-level
    ``db_path`` database; the query itself is not run. Missing or blank SQL is treated
    like a failed check so the model can still produce a query from the question.

    Args:
        original_sql: SQL text produced earlier (may be ``None``).
        user_question: The question the SQL should answer.
        schema_chunks: Iterable of schema text chunks, joined with newlines for the prompt.
        groq_client: Client exposing ``chat.completions.create``.
        model: Model name passed to the client.

    Returns:
        dict with ``final_sql``, ``was_healed``, ``error_before_healing``,
        ``explanation``, ``chart`` and ``status`` (``original_executed_successfully``,
        ``healed`` or ``healing_failed``).
    """
    error_message = None
    if not original_sql or not original_sql.strip():
        error_message = "No SQL query was provided."
    else:
        conn = None
        try:
            conn = sqlite3.connect(db_path)
            conn.execute("EXPLAIN QUERY PLAN " + original_sql)
        except sqlite3.Error as e:
            error_message = str(e)
        finally:
            # Close on success and on failure alike (the failure path used to leak it).
            if conn is not None:
                conn.close()

    if error_message is None:
        print("SQL executed successfully. No healing needed.")
        return {
            "final_sql": original_sql,
            "was_healed": False,
            "error_before_healing": None,
            "explanation": None,
            "chart": None,
            "status": "original_executed_successfully"
        }

    print(f"Execution error: {error_message}")
    print("Triggering LLM-based self-healing...")

    # Build retry prompt
    schema_context = "\n".join(schema_chunks or [])
    retry_prompt = get_self_healing_prompt(
        user_question=user_question,
        original_sql=original_sql,
        error_message=error_message,
        schema_context=schema_context
    )

    # Send to Groq
    retry_response = groq_client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are an expert SQL assistant that fixes broken queries."},
            {"role": "user", "content": retry_prompt}
        ],
        model=model
    )

    retry_text = retry_response.choices[0].message.content
    print("LLM Retry Output:\n", retry_text)

    # Parse output and accept the repaired SQL only if it passes both validators
    parsed_retry = parse_llm_response(retry_text)
    if parsed_retry["sql"]:
        syntax_ok, _ = check_sql_syntax(parsed_retry["sql"])
        ref_ok, _, _ = validate_sql_references(parsed_retry["sql"])

        if syntax_ok and ref_ok:
            print(" Healed SQL passed validation.")
            return {
                "final_sql": parsed_retry["sql"],
                "was_healed": True,
                "error_before_healing": error_message,
                "explanation": parsed_retry.get("explanation"),
                "chart": parsed_retry.get("chart"),
                "status": "healed"
            }

    print("Healing failed.")
    return {
        "final_sql": None,
        "was_healed": True,
        "error_before_healing": error_message,
        "explanation": None,
        "chart": None,
        "status": "healing_failed"
    }


In [None]:
# Run the Module 8 check/heal step on the SQL parsed from the Module 6 response.
result = execute_sql_with_self_heal(
    parsed["sql"],
    biz_question,
    schema_chunks,
    client,
)

# Report the outcome fields returned by the healer.
print(f"Final SQL: {result['final_sql']}")
print(f"Was Healed: {result['was_healed']}")
print(f"Error Before Healing: {result['error_before_healing']}")
print(f"Status: {result['status']}")


SQL executed successfully. No healing needed.
Final SQL: SELECT 
    Country, 
    COUNT(DISTINCT CustomerId) AS TotalCustomers
FROM 
    Customer
GROUP BY 
    Country
ORDER BY 
    TotalCustomers DESC
LIMIT 1;
Was Healed: False
Error Before Healing: None
Status: original_executed_successfully


Combining both module 7 and 8

In [None]:
def validate_then_execute_with_optional_heal(sql_text, user_question, schema_chunks, groq_client, db_path, model="llama-3.3-70b-versatile"):
    """Validate ``sql_text`` (Module 7), plan it on ``db_path`` and heal it on failure (Module 8).

    Flow:
      1. Missing/blank SQL or a failed syntax check returns ``module_7_validation_failed``
         without attempting execution.
      2. Otherwise the statement is planned with ``EXPLAIN QUERY PLAN`` on ``db_path``.
         Success returns the original SQL; a failed schema-reference check only
         downgrades the status to ``original_executed_with_warning``.
      3. If planning raises, ``execute_sql_with_self_heal`` is asked for a repaired query.

    Args:
        sql_text: SQL produced by the LLM (may be ``None``).
        user_question: Question the SQL should answer; forwarded to the healer.
        schema_chunks: Schema text chunks; forwarded to the healer.
        groq_client: Client forwarded to the healer.
        db_path: SQLite database used for the planning attempt in step 2.
        model: Model name forwarded to the healer.

    Returns:
        dict with ``original_sql``, ``healed_sql``, ``final_sql``, ``was_healed``,
        ``error_before_healing`` and ``status``.
    """
    if not sql_text or not sql_text.strip():
        # Nothing to validate; report it in the same shape as a syntax failure.
        return {
            "original_sql": sql_text,
            "healed_sql": None,
            "final_sql": None,
            "was_healed": False,
            "error_before_healing": "No SQL query was provided.",
            "status": "module_7_validation_failed"
        }

    print("Validating SQL in Module 7:")
    print(sql_text)

    syntax_ok, syntax_error = check_sql_syntax(sql_text)
    ref_ok, missing_tables, missing_columns = validate_sql_references(sql_text)

    print("Syntax OK:" if syntax_ok else f"Syntax error: {syntax_error}")
    print("Schema OK:" if ref_ok else f"Missing: Tables={missing_tables}, Columns={missing_columns}")

    if not syntax_ok:
        # Fails syntax — block execution
        return {
            "original_sql": sql_text,
            "healed_sql": None,
            "final_sql": None,
            "was_healed": False,
            "error_before_healing": syntax_error or "Unknown syntax error",
            "status": "module_7_validation_failed"
        }

    # Try the original SQL on the requested database; always release the connection.
    error_msg = None
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        conn.execute("EXPLAIN QUERY PLAN " + sql_text)
    except Exception as e:
        error_msg = str(e)
    finally:
        if conn is not None:
            conn.close()

    if error_msg is None:
        return {
            "original_sql": sql_text,
            "healed_sql": None,
            "final_sql": sql_text,
            "was_healed": False,
            "error_before_healing": None,
            "status": "original_executed_with_warning" if not ref_ok else "original_executed_successfully"
        }

    print("Execution error (triggering healing):", error_msg)

    # The healer takes the raw chunks and joins them itself; it has no schema_context
    # or db_path parameter, so passing those raised TypeError before any healing ran.
    healing_result = execute_sql_with_self_heal(
        original_sql=sql_text,
        user_question=user_question,
        schema_chunks=schema_chunks,
        groq_client=groq_client,
        model=model
    )

    healed_sql = healing_result.get("final_sql")

    if healed_sql:
        if healed_sql.strip() != sql_text.strip():
            print("Healing Success: SQL was modified")
            print("Original SQL:\n", sql_text)
            print("Healed SQL:\n", healed_sql)

        return {
            "original_sql": sql_text,
            "healed_sql": healed_sql,
            "final_sql": healed_sql,
            "was_healed": True,
            "error_before_healing": error_msg,
            "status": "healed"
        }

    return {
        "original_sql": sql_text,
        "healed_sql": None,
        "final_sql": None,
        "was_healed": True,
        "error_before_healing": error_msg,
        "status": "healing_failed"
    }


In [None]:
"""import json
with open("/content/schema_metadata_enriched_with_tags.json") as f:
    schema = json.load(f)

table_names = [t["table_name"] for t in schema]
print("Tables extracted from schema:\n", table_names)"""

'import json\nwith open("/content/schema_metadata_enriched_with_tags.json") as f:\n    schema = json.load(f)\n\ntable_names = [t["table_name"] for t in schema]\nprint("Tables extracted from schema:\n", table_names)'

In [None]:
"""for table in schema:
    print(f"\nTable: {table['table_name']}")
    print("Columns:", [col["column_name"] for col in table["columns"]])"""

'for table in schema:\n    print(f"\nTable: {table[\'table_name\']}")\n    print("Columns:", [col["column_name"] for col in table["columns"]])'

Module 9 Plan with Knowledge Base Placeholder
We’ll:

Retrieve the business chunks if the knowledge base (genbi_vectorstore) is available

Else, just pass an empty list or placeholder string

Prompt the LLM to generate a clear, business-friendly insight from the SQL + original question

In [None]:
def generate_sql_insight(user_question, final_sql, genbi_vectorstore=None, groq_client=None, model="llama-3.3-70b-versatile"):
    """Ask the LLM for a business-level explanation of a SQL query.

    Args:
        user_question: The original question.
        final_sql: The SQL text to explain.
        genbi_vectorstore: Optional vector store; when truthy, up to 3 documents
            retrieved for the question are inserted as business rules. Otherwise a
            placeholder line is used.
        groq_client: Client exposing ``chat.completions.create``. Required, even
            though the signature keeps its ``None`` default.
        model: Model name passed to the client.

    Returns:
        The text content of the first choice of the model response.

    Raises:
        ValueError: If ``groq_client`` is ``None``.
    """
    if groq_client is None:
        raise ValueError("groq_client is required to generate a SQL insight.")

    # Step 1: Try to retrieve relevant business rules (if knowledge base is enabled)
    if genbi_vectorstore:
        genbi_docs = genbi_vectorstore.as_retriever(search_kwargs={"k": 3}).get_relevant_documents(user_question)
        # dict.fromkeys de-duplicates while keeping retrieval order, so the prompt
        # is the same on every run (a set reorders strings per process).
        business_chunks = list(dict.fromkeys(doc.page_content for doc in genbi_docs))
    else:
        business_chunks = ["<Business rules are not currently available.>"]

    # Step 2: Build the explanation prompt
    prompt = f"""
You are a helpful business assistant.

User Question:
{user_question}

SQL Query:
{final_sql}

Business Rules:
{chr(10).join(business_chunks)}

Your tasks:
- Summarize the kind of insight or KPI this query would produce
- Suggest 1–2 follow-up business questions based on this data

Respond in this format:

Explanation:
<Your plain-language explanation>

Insight Summary:
<What insight or KPI this generates>

Follow-up Questions:
- ...
- ...
""".strip()

    # Step 3: Call Groq LLM
    response = groq_client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are an expert in translating SQL queries into business insights."},
            {"role": "user", "content": prompt}
        ],
        model=model
    )

    # Step 4: Return LLM Output
    return response.choices[0].message.content


In [None]:
# Generate the business insight for the SQL returned by Module 8.
# genbi_vectorstore may be None when no knowledge base is loaded.
insight_text = generate_sql_insight(
    biz_question,
    result["final_sql"],
    genbi_vectorstore=genbi_vectorstore,
    groq_client=client,
)

# Heading, one blank line, then the model's answer.
print("Business Insight:\n", insight_text, sep="\n")


Business Insight:

Explanation:
This SQL query is designed to identify the country with the highest number of unique customers. It does this by selecting the 'Country' field and counting the distinct 'CustomerId' for each country in the 'Customer' table. The results are then grouped by country and ordered in descending order based on the total number of customers, with the top result being the country with the most customers.

Insight Summary:
This query generates the insight of which country has the largest customer base, providing a key performance indicator (KPI) for geographical customer distribution.

Follow-up Questions:
- What is the average order value or total revenue generated from customers in this country, to understand the financial impact of this customer base?
- How does the customer growth rate in this country compare to other regions, to identify opportunities for expansion or areas that may require more marketing efforts?


Module 10

In [None]:
#def extract_chart_type_from_text(text):
    #if not text:
    #    return "auto"
    ## if "bar" in text:
       # return "bar"
   # elif "line" in text or "trend" in text:
       # return "line"
   # elif "pie" in text:
        #return "pie"
    #elif "scatter" in text:
        #return "scatter"
    #else:
        #return "auto"


In [None]:
def extract_chart_info(text):
    """Read the chart hint (type, x column, y columns) from an LLM chart block.

    Matching of the ``ChartType:``, ``X:`` and ``Y:`` labels is case-insensitive, but
    column names are returned with their original casing. Values are taken from the
    label's own line only.

    Args:
        text: Chart text such as ``"ChartType: bar\\nX: Genre\\nY: TotalRevenue"``.
            Non-string input yields the defaults.

    Returns:
        ``(chart_type, x_col, y_col)`` where ``chart_type`` is lower-case and defaults
        to ``"auto"``, ``x_col`` is a string or ``None``, and ``y_col`` is a list of
        names or ``None`` when no Y line is present.
    """
    chart_type = "auto"
    x_col = None
    y_col = None

    if not isinstance(text, str):
        return chart_type, x_col, y_col

    # ChartType
    type_match = re.search(r"\bcharttype:[ \t]*(\w+)", text, re.IGNORECASE)
    if type_match:
        chart_type = type_match.group(1).lower()

    # X axis
    x_match = re.search(r"\bx:[ \t]*([a-zA-Z_][a-zA-Z0-9_]*)", text, re.IGNORECASE)
    if x_match:
        x_col = x_match.group(1)

    # Y axis (support for multiple comma-separated values)
    y_match = re.search(r"\by:[ \t]*([a-zA-Z0-9_, \t]+)", text, re.IGNORECASE)
    if y_match:
        names = [name.strip() for name in y_match.group(1).split(",") if name.strip()]
        y_col = names or None

    return chart_type, x_col, y_col


In [None]:
import sqlite3
import pandas as pd
import plotly.express as px

def _match_column(name, columns):
    """Return the entry of ``columns`` equal to ``name`` ignoring case, or None."""
    if not isinstance(name, str):
        return None
    wanted = name.strip().split("\n")[0].strip().lower()
    for col in columns:
        if str(col).lower() == wanted:
            return col
    return None


def generate_chart_from_sql(
    final_sql,
    db_path="/content/Chinook_Sqlite.sqlite",
    chart_type="auto",
    chart_title=None,
    x_col=None,
    y_col=None
):
    """Run ``final_sql`` on a SQLite database, preview the result and draw a Plotly chart.

    Args:
        final_sql: SQL text to execute.
        db_path: Path of the SQLite database file.
        chart_type: ``"bar"``, ``"line"``, ``"pie"``, ``"scatter"`` or ``"auto"``
            (line when the first column name contains "date"/"year" and the result
            has more than two columns, otherwise bar).
        chart_title: Optional title; defaults to ``"<y columns> by <x column>"``.
            A unit suffix derived from the y column names is appended either way.
        x_col: Suggested x column. Matched against the result columns ignoring case;
            if it is missing or unknown, the first non-numeric column is used.
        y_col: Suggested y column name or list of names. Matched ignoring case; if
            none of them exist, the first numeric column is used.

    Returns:
        ``(df, fig)``. ``(None, None)`` when the SQL fails; ``(df, None)`` when the
        result is empty, has fewer than two columns, no y column can be determined,
        or the chart could not be built.
    """
    # Step 1: Run SQL (the connection is closed whether or not the query succeeds)
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        df = pd.read_sql_query(final_sql, conn)
    except Exception as e:
        print(f"Failed to execute SQL: {e}")
        return None, None
    finally:
        if conn is not None:
            conn.close()

    df.columns = [str(col).strip().split("\n")[0] for col in df.columns]  # Clean bad LLM aliases

    # Step 2: Show DataFrame
    print("Data Preview:")
    display(df)

    if df.empty or df.shape[1] < 2:
        print("Not enough data to generate chart.")
        return df, None

    # Step 3: Detect chart type if not set
    if chart_type == "auto":
        if df.shape[1] == 2:
            chart_type = "bar"
        elif "date" in df.columns[0].lower() or "year" in df.columns[0].lower():
            chart_type = "line"
        else:
            chart_type = "bar"

    # Step 4: Column selection (use LLM override or auto-detect)
    numeric_cols = df.select_dtypes(include='number').columns.tolist()
    categorical_cols = df.select_dtypes(exclude='number').columns.tolist()

    # Drop rows with NaNs in numeric fields (non-inplace: df stays a fresh object)
    if numeric_cols:
        df = df.dropna(subset=numeric_cols)

    # Resolve the suggested x column; an unknown name falls back to auto-detection
    # instead of raising KeyError further down.
    requested_x = x_col
    x_col = _match_column(requested_x, df.columns)
    if x_col is None:
        if requested_x:
            print(f"x_col '{requested_x}' not found in DataFrame columns. Auto-detecting.")
        for col in categorical_cols:
            if col.lower() not in ["id", "index", "rowid"]:
                x_col = col
                break

    if isinstance(y_col, str):
        requested_y = [y_col]
    elif isinstance(y_col, list):
        requested_y = y_col
    else:
        requested_y = []

    # Keep only y columns that exist in the DataFrame (case-insensitive match)
    y_col_final = []
    for name in requested_y:
        matched = _match_column(name, df.columns)
        if matched is not None and matched not in y_col_final:
            y_col_final.append(matched)

    if not y_col_final:
        if requested_y:
            print("Suggested y_col not found in DataFrame columns. Using first numeric column.")
        y_col_final = [col for col in numeric_cols if col != x_col][:1]

    if not y_col_final:
        print("No valid y_col found in DataFrame columns. Chart may not render.")
        return df, None

    # Step 4.5: Add contextual unit labels to chart title
    title_suffix = ""
    if any("revenue" in col.lower() or "amount" in col.lower() for col in y_col_final):
        title_suffix = " (in USD)"
    elif any("growth" in col.lower() or "percent" in col.lower() for col in y_col_final):
        title_suffix = " (%)"
    elif any("quantity" in col.lower() or "count" in col.lower() or "tracks" in col.lower() for col in y_col_final):
        title_suffix = " (tracks)"
    elif any("customer" in col.lower() for col in y_col_final):
        title_suffix = " (customers)"

    full_title = (chart_title or f"{', '.join(y_col_final)} by {x_col}") + title_suffix

    # Step 4.9: Drop rows where every plotted column is missing. The subset is built
    # from the resolved columns and is never empty here, so dropna is not called
    # with an empty subset.
    columns_to_check = ([x_col] if x_col else []) + y_col_final
    df = df.dropna(subset=columns_to_check, how='all')

    # Step 5: Generate chart
    fig = None
    try:
        if chart_type == "bar":
            fig = px.bar(df, x=x_col, y=y_col_final, title=full_title)
        elif chart_type == "line":
            fig = px.line(df, x=x_col, y=y_col_final, title=full_title)
        elif chart_type == "pie":
            if len(y_col_final) == 1:
                fig = px.pie(df, names=x_col, values=y_col_final[0], title=full_title)
            else:
                print("Pie chart requires exactly one numeric column.")
        elif chart_type == "scatter":
            fig = px.scatter(df, x=x_col, y=y_col_final[0], title=full_title)
        else:
            print(f"Unsupported chart type: {chart_type}")
    except Exception as e:
        print(f"Chart generation error: {e}")

    # Step 6: Show chart
    if fig:
        fig.show()

    return df, fig


In [None]:
# Show the chart hint parsed from the LLM response.
print(f"ChartType: {chart_type}")
print(f"X col: {x_col}")
print(f"Y col: {y_col}")

ChartType: bar
X col: Country
Y col: ['TotalCustomers']


In [None]:
# Plot the SQL that survived Module 8 (falls back to the parsed SQL if healing
# returned nothing). The title is the business question: `parsed` has no
# "question" key, so parsed.get("question", ...) always produced "Query Result".
df, fig = generate_chart_from_sql(
    final_sql=result["final_sql"] or parsed["sql"],
    db_path="/content/Chinook_Sqlite.sqlite",  # or your variable
    chart_type=chart_type,
    chart_title=biz_question or "Query Result",
    x_col=x_col,
    y_col=y_col
)



Data Preview:


Unnamed: 0,Country,TotalCustomers
0,USA,13


Full Pipeline

In [None]:
def run_fresh_pipeline(user_question, groq_client, table_vectorstore, column_vectorstore, genbi_vectorstore=None, db_path="/content/Chinook_Sqlite.sqlite", model="llama-3.3-70b-versatile"):
    """Run the full question -> SQL -> validation -> insight -> chart pipeline.

    Args:
        user_question: Natural-language business question.
        groq_client: Client exposing ``chat.completions.create``; also handed to
            the healing and insight helpers.
        table_vectorstore: Vector store queried (k=6) for table-level schema chunks.
        column_vectorstore: Vector store queried (k=6) for column-level schema chunks.
        genbi_vectorstore: Optional vector store queried (k=3) for business-rule
            chunks. When falsy, a placeholder string is sent to the prompt instead.
        db_path: Path of the SQLite database the final SQL is executed against.
        model: Chat model name used for SQL generation and healing.

    Returns:
        dict with the keys ``question``, ``final_sql``, ``was_healed``,
        ``syntax_check_passed``, ``schema_check_passed``, ``explanation``,
        ``insight``, ``dataframe``, ``chart``, ``chart_details`` and ``status``.
        When validation yields no usable SQL the function returns early with
        ``final_sql``/``explanation``/``insight``/``dataframe``/``chart`` set to
        ``None`` and both ``*_check_passed`` flags set to ``False``.
        Note that the two ``*_check_passed`` flags are derived from which path
        was taken (True on success, False on early exit); they are not read
        from the validation result.

    Raises:
        Whatever the retrievers, the LLM client, the helper modules or
        ``pandas.read_sql_query`` raise; nothing is swallowed here.
    """
    # STEP 1: RAG - retrieve schema and business chunks.
    table_docs = table_vectorstore.as_retriever(search_kwargs={"k": 6}).get_relevant_documents(user_question)
    column_docs = column_vectorstore.as_retriever(search_kwargs={"k": 6}).get_relevant_documents(user_question)
    # dict.fromkeys removes duplicates while keeping retrieval order, so the
    # prompt is identical across runs (a set would shuffle the chunks).
    schema_chunks = list(dict.fromkeys(doc.page_content for doc in (table_docs + column_docs)))

    if genbi_vectorstore:
        genbi_docs = genbi_vectorstore.as_retriever(search_kwargs={"k": 3}).get_relevant_documents(user_question)
        business_chunks = list(dict.fromkeys(doc.page_content for doc in genbi_docs))
    else:
        business_chunks = ["<Business knowledge base not loaded>"]

    # STEP 2: Prompt & SQL generation (Module 6).
    schema_context = "\n".join(schema_chunks)

    user_prompt = get_sql_generation_prompt(
        user_question=user_question,
        schema_context=schema_context,
        business_rules_context="\n".join(business_chunks),
        examples=[]
    )

    response = groq_client.chat.completions.create(
        messages=[
            {"role": "system", "content": get_system_prompt()},
            {"role": "user", "content": user_prompt}
        ],
        model=model
    )

    llm_response = response.choices[0].message.content
    parsed = parse_llm_response(llm_response)

    # STEP 3: SQL validation + healing (Modules 7 + 8).
    validation_result = validate_then_execute_with_optional_heal(
        sql_text=parsed["sql"],
        user_question=user_question,
        schema_chunks=schema_chunks,
        groq_client=groq_client,
        db_path=db_path,
        model=model
    )
    was_healed = validation_result.get("was_healed", False)
    status = validation_result.get("status")

    # Prefer the healed SQL only when healing actually happened.
    final_sql = (
        validation_result.get("healed_sql")
        if was_healed
        else validation_result.get("original_sql")
    )

    # Override when the selected SQL contains a default fallback inside LAG
    # (e.g. "LAG(x, 1, 0)"): fall back to the SQL the LLM produced originally.
    # The truthiness check keeps a missing final_sql from raising TypeError
    # here, so the early-exit branch below is reachable.
    if final_sql and "LAG" in final_sql and ", 0)" in final_sql:
        print("Healed SQL has default fallback in LAG — reverting to original LLM SQL.")
        final_sql = parsed["sql"]

    # No valid SQL: exit early with an all-empty result.
    if not final_sql:
        return {
            "question": user_question,
            "final_sql": None,
            "was_healed": was_healed,
            "syntax_check_passed": False,
            "schema_check_passed": False,
            "explanation": None,
            "insight": None,
            "dataframe": None,
            "chart": None,
            "chart_details": {
                "chart_type": None,
                "x": None,
                "y": None
            },
            "status": status
        }

    # STEP 4: Generate insight (Module 9).
    insight_text = generate_sql_insight(
        user_question=user_question,
        final_sql=final_sql,
        genbi_vectorstore=genbi_vectorstore,
        groq_client=groq_client
    )

    # STEP 5: Visualize using the Module 10 chart logic only.
    # "or {}" also covers a parsed response whose chart_details is present but null.
    chart_details = parsed.get("chart_details") or {}
    chart_type = chart_details.get("chart_type", "auto")
    x_col = chart_details.get("x")
    y_col = _normalize_y_columns(chart_details.get("y"))

    print("Final SQL to Execute:\n", final_sql)
    print("Chart Details:", chart_details)

    # Debug preview of the raw query result before charting. The connection is
    # closed explicitly; sqlite3 connections are not released by read_sql_query.
    preview_conn = sqlite3.connect(db_path)
    try:
        df_test = pd.read_sql_query(final_sql, preview_conn)
    finally:
        preview_conn.close()
    print("Raw SQL Result Preview:")
    print(df_test.head())
    print(" Shape:", df_test.shape)

    # Generate the chart with the unified Module 10 function.
    df, fig = generate_chart_from_sql(
        final_sql=final_sql,
        db_path=db_path,
        chart_type=chart_type,
        chart_title=user_question,
        x_col=x_col,
        y_col=y_col
    )

    # Show the table even when no chart could be produced.
    if df is not None and not df.empty:
        print("Final DataFrame Preview:")
        display(df)
    else:
        print(" No usable data for chart or table.")

    # STEP 6: Return final results.
    return {
        "question": user_question,
        "final_sql": final_sql,
        "was_healed": was_healed,
        "syntax_check_passed": True,
        "schema_check_passed": True,
        "explanation": parsed.get("explanation"),
        "insight": insight_text,
        "dataframe": df,
        "chart": fig,
        "status": status,
        "chart_details": chart_details  # kept in the result for debugging
    }


def _normalize_y_columns(y_raw):
    """Convert the chart ``y`` field into a list of column names, or None.

    Accepts a comma-separated string ("a, b"), a list/tuple of names, or any
    other single value. Surrounding whitespace is stripped and empty names are
    dropped; None is returned when nothing usable remains.
    """
    if not y_raw:
        return None
    if isinstance(y_raw, str):
        candidates = y_raw.split(",")
    elif isinstance(y_raw, (list, tuple)):
        candidates = list(y_raw)
    else:
        candidates = [y_raw]
    columns = [str(name).strip() for name in candidates]
    columns = [name for name in columns if name]
    return columns or None


In [None]:
# Example usage: run the full pipeline for the current business question.
result = run_fresh_pipeline(
    user_question=biz_question,
    groq_client=client,
    table_vectorstore=table_vectorstore,
    column_vectorstore=column_vectorstore,
    genbi_vectorstore=genbi_vectorstore)


# Summarize the outcome of SQL generation and validation.
print("Question:", result["question"])
print("Final SQL:", result["final_sql"])
print("Was Healed:", result["was_healed"])
print("Syntax Check Passed:", result["syntax_check_passed"])
print("Schema Check Passed:", result["schema_check_passed"])


# The result table is already displayed inside run_fresh_pipeline
# ("Final DataFrame Preview") and the chart is shown by generate_chart_from_sql,
# so nothing else is rendered here. result["dataframe"] and result["chart"]
# remain available for further inspection.




Validating SQL in Module 7:
SELECT 
  Country, 
  COUNT(DISTINCT CustomerId) AS TotalCustomers
FROM 
  Customer
GROUP BY 
  Country
ORDER BY 
  TotalCustomers DESC
LIMIT 1;
Detected CTEs: set()
Syntax OK:
Schema OK:
Final SQL to Execute:
 SELECT 
  Country, 
  COUNT(DISTINCT CustomerId) AS TotalCustomers
FROM 
  Customer
GROUP BY 
  Country
ORDER BY 
  TotalCustomers DESC
LIMIT 1;
Chart Details: {'chart_type': 'bar', 'x': 'Country', 'y': 'TotalCustomers'}
Raw SQL Result Preview:
  Country  TotalCustomers
0     USA              13
 Shape: (1, 2)
Data Preview:


Unnamed: 0,Country,TotalCustomers
0,USA,13


Final DataFrame Preview:


Unnamed: 0,Country,TotalCustomers
0,USA,13


Question: Which country has the most customers?
Final SQL: SELECT 
  Country, 
  COUNT(DISTINCT CustomerId) AS TotalCustomers
FROM 
  Customer
GROUP BY 
  Country
ORDER BY 
  TotalCustomers DESC
LIMIT 1;
Was Healed: False
Syntax Check Passed: True
Schema Check Passed: True


# UI

In [None]:
!pip install gradio

Collecting gradio
  Downloading gradio-5.29.0-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<25.0,>=22.0 (from gradio)
  Downloading aiofiles-24.1.0-py3-none-any.whl.metadata (10 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.10.0 (from gradio)
  Downloading gradio_client-1.10.0-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.9.3 (from gradio)
  Downloading ruff-0.11.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
Collecting safehttpx<0.2.0,>=0.1.6

In [None]:
import gradio as gr


In [None]:
## Gradio interface function - using Gradio Plot component
def gradio_pipeline_ui(user_question):
    """Gradio callback: answer one question and unpack the result for the UI.

    Delegates to ``run_fresh_pipeline`` using the notebook-level ``client`` and
    vector stores, then returns the pieces in the order the output components
    expect: (SQL text, explanation, insight, result DataFrame, Plotly figure).
    The figure is ``None`` when the pipeline result carries no ``"chart"`` key.
    """
    pipeline_output = run_fresh_pipeline(
        user_question=user_question,
        groq_client=client,
        table_vectorstore=table_vectorstore,
        column_vectorstore=column_vectorstore,
        genbi_vectorstore=genbi_vectorstore,
    )

    sql_text = pipeline_output["final_sql"]
    explanation_text = pipeline_output["explanation"]
    insight_text = pipeline_output["insight"]
    result_frame = pipeline_output["dataframe"]
    # The figure object is handed to gr.Plot as-is; a missing key becomes None.
    figure = pipeline_output.get("chart")

    return sql_text, explanation_text, insight_text, result_frame, figure

# Build the Gradio UI: a question box with a run button on top, and one tab per
# pipeline output below. The chart tab uses gr.Plot (not raw HTML) so the Plotly
# figure returned by the pipeline can be passed through directly.
with gr.Blocks(title="AIDA BI Assistant") as demo:
    gr.Markdown("# AIDA BI Assistant")
    gr.Markdown("Ask questions in plain English and see SQL, insights, and charts!")

    # Input row: free-text question (pre-filled with a sample) and the trigger button.
    with gr.Row():
        question = gr.Textbox(label="Ask a Business Question", value="What is the year on year revenue")
        submit_btn = gr.Button("Run Query", variant="primary")

    # Output tabs, one per element of the tuple returned by gradio_pipeline_ui.
    with gr.Tabs():
        with gr.TabItem("Generated SQL"):
            sql_output = gr.Code(language="sql", label="SQL Query")

        with gr.TabItem("Explanation"):
            explanation = gr.Textbox(label="Query Explanation", lines=5)

        with gr.TabItem("Business Insights"):
            insights = gr.Textbox(label="Business Insights", lines=10)

        with gr.TabItem("Results"):
            results = gr.Dataframe(label="Query Results")

        with gr.TabItem("Chart"):
            # Use Gradio's Plot component instead of HTML
            chart = gr.Plot(label="Visualization")

    # Wire the button to the pipeline. The order of `outputs` must match the
    # order of the tuple returned by gradio_pipeline_ui:
    # (final_sql, explanation, insight, dataframe, chart).
    submit_btn.click(
        fn=gradio_pipeline_ui,
        inputs=[question],
        outputs=[sql_output, explanation, insights, results, chart]
    )

# Launch the interface. share=True requests a public *.gradio.live URL in
# addition to the local server.
# NOTE(review): presumably anyone holding that link can run queries against the
# database through this app - confirm this exposure is intended before sharing.
demo.launch(share=True)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://69af5d73b7fc2fb006.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


