<a href="https://colab.research.google.com/github/piyush080205/Schema-Assist/blob/main/SchemaAssist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q duckdb cohere faiss-cpu plotly pandas gradio prophet kaleido

from google.colab import output
output.enable_custom_widget_manager()

In [2]:
import numpy as np
import pandas as pd
import cohere
import faiss
import duckdb
import plotly.express as px
from prophet import Prophet
from prophet.plot import plot_plotly
import gradio as gr
import warnings
import re
from google.colab import userdata, files

# Silence every Python warning for cleaner notebook output.
# NOTE(review): this also hides deprecation warnings; narrow the filter when debugging.
warnings.filterwarnings('ignore')

# The Cohere key comes from Colab's Secrets panel instead of being hard-coded here.
COHERE_API_KEY = userdata.get('COHERE_API_KEY')
co = cohere.ClientV2(api_key=COHERE_API_KEY)
# A single in-memory DuckDB database; the upload cell turns each CSV into a table in it.
con = duckdb.connect(":memory:")

print("✅ Setup complete")

✅ Setup complete


In [3]:
print("📤 Upload all 9 Olist CSV files now...")
uploaded = files.upload()

# Olist file stem -> short table name used by the prompts, the schema documents and the SQL.
table_mapping = {
    "olist_customers_dataset": "customers",
    "olist_geolocation_dataset": "geolocation",
    "olist_order_items_dataset": "order_items",
    "olist_order_payments_dataset": "order_payments",
    "olist_order_reviews_dataset": "order_reviews",
    "olist_orders_dataset": "orders",
    "olist_products_dataset": "products",
    "olist_sellers_dataset": "sellers",
    "product_category_name_translation": "category_translation"
}


def _table_name_for(file_stem):
    """Return the DuckDB table name for an uploaded file stem.

    Known Olist files use ``table_mapping``. Any other file gets the stem minus the
    ``olist_``/``_dataset`` decorations. Non-word characters become ``_``, and a
    ``t_`` prefix is added if the name would start with a digit, so the result is
    a valid bare SQL identifier.
    """
    if file_stem in table_mapping:
        return table_mapping[file_stem]
    fallback = re.sub(r'\W', '_', file_stem.replace('olist_', '').replace('_dataset', ''))
    return f"t_{fallback}" if fallback[:1].isdigit() else fallback


loaded_tables, failed_files = [], []
for uploaded_name in uploaded.keys():
    # Colab saves repeat uploads as e.g. "olist_orders_dataset (1).csv": drop that suffix and the extension.
    base = uploaded_name.split(' (')[0]
    if base.lower().endswith('.csv'):
        base = base[:-len('.csv')]
    table_name = _table_name_for(base)
    # Double any single quote so the file name cannot terminate the SQL string literal early.
    csv_path = uploaded_name.replace("'", "''")
    try:
        con.execute(f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM read_csv_auto('{csv_path}', header=True)")
        loaded_tables.append(table_name)
        print(f"   ✅ Loaded → {table_name}")
    except Exception as e:
        # Report and continue so one bad file does not block the others; failures are summarised below.
        failed_files.append(uploaded_name)
        print(f"   ❌ {uploaded_name}: {e}")

# Only claim success when every expected Olist table actually loaded.
missing_tables = sorted(set(table_mapping.values()) - set(loaded_tables))
if failed_files or missing_tables:
    print(f"\n⚠️ Loaded {len(loaded_tables)} table(s).")
    if failed_files:
        print(f"   Failed files: {', '.join(failed_files)}")
    if missing_tables:
        print(f"   Missing expected tables: {', '.join(missing_tables)}")
else:
    print("\n🎉 All tables loaded!")

📤 Upload all 9 Olist CSV files now...


Saving olist_customers_dataset.csv to olist_customers_dataset (1).csv
Saving olist_geolocation_dataset.csv to olist_geolocation_dataset (1).csv
Saving olist_order_items_dataset.csv to olist_order_items_dataset (1).csv
Saving olist_order_payments_dataset.csv to olist_order_payments_dataset (1).csv
Saving olist_order_reviews_dataset.csv to olist_order_reviews_dataset (1).csv
Saving olist_orders_dataset.csv to olist_orders_dataset (1).csv
Saving olist_products_dataset.csv to olist_products_dataset (1).csv
Saving olist_sellers_dataset.csv to olist_sellers_dataset (1).csv
Saving product_category_name_translation.csv to product_category_name_translation (1).csv
   ✅ Loaded → customers
   ✅ Loaded → geolocation
   ✅ Loaded → order_items
   ✅ Loaded → order_payments
   ✅ Loaded → order_reviews
   ✅ Loaded → orders
   ✅ Loaded → products
   ✅ Loaded → sellers
   ✅ Loaded → category_translation

🎉 All tables loaded!


In [4]:
def extract_metadata(con):
    """Describe every table visible on a DuckDB connection.

    Args:
        con: Connection whose ``sql(...).df()`` results are read.

    Returns:
        Dict mapping each table name, in ``SHOW TABLES`` order, to
        ``{"columns": [{"column_name": ..., "column_type": ...}, ...]}``, listing
        columns in the order ``DESCRIBE`` reports them.
    """
    table_names = con.sql("SHOW TABLES").df()['name'].tolist()

    def _describe(table):
        # DESCRIBE may report extra fields; keep only the name and the type.
        described = con.sql(f"DESCRIBE {table}").df()
        return described[['column_name', 'column_type']].to_dict('records')

    return {table: {"columns": _describe(table)} for table in table_names}

schema_metadata = extract_metadata(con)
if not schema_metadata:
    # Without tables, co.embed(texts=[]) and embeddings.shape[1] would fail with opaque errors.
    raise RuntimeError("No tables found in DuckDB - run the upload cell first.")
print(f"✅ {len(schema_metadata)} tables loaded")

# One retrieval document per table: its name plus every column with its DuckDB type.
# documents[i] corresponds to FAISS id i; respond() maps search hits back through this list.
documents = []
for table_name, info in schema_metadata.items():
    col_str = "\n".join(f"  - {c['column_name']} ({c['column_type']})" for c in info["columns"])
    # The same category_translation hint is appended to every table's document.
    doc = f"Table: {table_name}\nColumns:\n{col_str}\nUse category_translation for English names."
    documents.append(doc)

# Embed the schema documents with Cohere and index them for nearest-neighbour lookup.
embed_response = co.embed(model="embed-v4.0", texts=documents, input_type="search_document", embedding_types=["float"])
embeddings = np.array(embed_response.embeddings.float).astype('float32')
# Exact (brute-force) L2 index; with one vector per table this stays tiny.
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
print("✅ FAISS ready!")

✅ Extracted 9 tables
✅ FAISS ready!


In [5]:
# Matches the first Markdown code fence. A language tag on the fence line (```sql, ```SQL,
# ```duckdb, ...) is skipped, but a SQL keyword there (```SELECT ...) is kept as query text.
_CODE_FENCE_RE = re.compile(
    r"```(?:(?!(?:select|with|from)\b)[\w+-]*[ \t]*\r?\n|sql\b)?(.*?)(?:```|\Z)",
    re.DOTALL | re.IGNORECASE,
)


def _strip_code_fence(text):
    """Return the contents of the first code fence in ``text`` (or ``text`` itself), stripped.

    Only the fence's language tag is removed, so identifiers that merely contain
    "sql" (e.g. ``mysql_id``) are left intact.
    """
    match = _CODE_FENCE_RE.search(text)
    return (match.group(1) if match else text).strip()


def generate_sql(user_query, schema_context, chat_history=None):
    """Ask the Cohere chat model to turn a natural-language question into DuckDB SQL.

    Args:
        user_query: The user's question.
        schema_context: Retrieved schema documents sent alongside the question.
        chat_history: Prior turns as dicts with "query" and "sql" keys. Only the last
            five are replayed so follow-up questions keep their context. Not modified.

    Returns:
        The model's SQL with any Markdown code fence removed.
    """
    if chat_history is None:
        chat_history = []
    system_prompt = """You are an expert Olist analyst.
Rules:
- ALWAYS join with `category_translation` when showing product categories to show English names.
- Return ONLY valid SELECT query with LIMIT 1000.
- Use correct joins.
- For revenue use payment_value.
- Return ONLY the SQL."""

    messages = [{"role": "system", "content": system_prompt}]
    for turn in chat_history[-5:]:
        messages.append({"role": "user", "content": turn["query"]})
        messages.append({"role": "assistant", "content": turn["sql"]})
    messages.append({"role": "user", "content": f"Schema:\n{schema_context}\n\nQuestion: {user_query}"})

    resp = co.chat(model="command-a-03-2025", messages=messages, temperature=0.1)
    return _strip_code_fence(resp.message.content[0].text)

In [6]:
# String literals, quoted identifiers and comments. They are blanked before keyword checks so
# that e.g. WHERE note = 'drop' cannot trigger a false match, and a comment cannot hide a statement.
_SQL_NOISE_RE = re.compile(r"""'(?:[^']|'')*'|"(?:[^"]|"")*"|--[^\n]*|/\*.*?\*/""", re.DOTALL)
# Statements that write data, change the catalog or settings, touch files or load extensions.
_FORBIDDEN_SQL_RE = re.compile(
    r"\b(?:DROP|DELETE|UPDATE|INSERT|CREATE|ALTER|TRUNCATE|COPY|ATTACH|DETACH|INSTALL|LOAD"
    r"|PRAGMA|SET|RESET|EXPORT|IMPORT|CALL|CHECKPOINT|VACUUM)\b",
    re.IGNORECASE,
)
# Leading keywords of read-only queries (DuckDB also accepts FROM-first SELECTs).
_READ_ONLY_STARTS = ("SELECT", "WITH", "FROM")


def _sql_skeleton(sql):
    """Return ``sql`` with literals emptied, comments removed and trailing semicolons dropped."""
    def _blank(match):
        token = match.group(0)
        return " " if token.startswith(("--", "/*")) else "''"
    without_noise = _SQL_NOISE_RE.sub(_blank, sql)
    return re.sub(r"[\s;]+$", "", without_noise).strip()


def run_query(sql):
    """Execute one read-only query against ``con`` and return the result as a DataFrame.

    Keyword checks use whole words on the SQL with literals and comments removed. A
    column such as ``last_update`` is therefore allowed, while ``SELECT 1; DROP ...``
    or ``COPY ... TO`` are rejected.

    Raises:
        ValueError: if ``sql`` is empty, holds more than one statement, does not
            start with SELECT/WITH/FROM, or uses a write/admin keyword.
    """
    skeleton = _sql_skeleton(sql or "")
    if not skeleton:
        raise ValueError("Only SELECT allowed: empty query")
    if ";" in skeleton:
        raise ValueError("Only SELECT allowed: multiple statements")
    first_word = re.match(r"[(\s]*(\w*)", skeleton).group(1).upper()
    if first_word not in _READ_ONLY_STARTS:
        raise ValueError(f"Only SELECT allowed: query starts with {first_word or '?'}")
    forbidden = _FORBIDDEN_SQL_RE.search(skeleton)
    if forbidden:
        raise ValueError(f"Only SELECT allowed: found {forbidden.group(0).upper()}")
    return con.sql(sql).df()

def explain_results(df, user_query, sql):
    """Summarise a query result as Markdown for the chat window.

    Reports the row count. When the result has a numeric column, it also reports the
    mean, median, mode and range of the first one, and names that column so the
    reader knows what the figures describe.

    Args:
        df: The query result.
        user_query: The user's question (currently unused; kept for interface compatibility).
        sql: The executed SQL (currently unused; kept for interface compatibility).

    Returns:
        ``(explanation, stats_text)``: the Markdown shown in the chat and the bare
        statistics block. Returns ``("No data found.", "")`` when ``df`` is empty.
    """
    if df is None or df.empty:
        return "No data found.", ""

    rows = len(df)
    stats = [f"**Rows:** {rows}"]
    num_cols = df.select_dtypes(include='number').columns.tolist()
    if num_cols:
        col = num_cols[0]
        values = df[col]
        modes = values.mode()
        mode_text = f"{modes.iloc[0]:,.2f}" if not modes.empty else 'N/A'
        stats.extend([
            f"**Column:** `{col}`",
            f"**Mean / Average:** {values.mean():,.2f}",
            f"**Median:** {values.median():,.2f}",
            f"**Mode:** {mode_text}",
        ])
        # A data-derived observation instead of a fixed verdict that ignores the numbers.
        insight = (f"`{col}` ranges from {values.min():,.2f} to {values.max():,.2f} "
                   f"across {rows} row(s).")
    else:
        insight = f"{rows} row(s) returned; no numeric column to summarise."

    stats_text = "\n".join(stats)
    return f"{stats_text}\n\n💡 Business Insight: {insight}", stats_text

In [7]:
# Substrings that mark a column as a BRL money amount (used only for axis labels).
_MONEY_COLUMN_HINTS = ("revenue", "payment", "price", "value", "freight", "sales", "gmv")


def _axis_label(column):
    """Turn a snake_case column name into a Title Case axis label."""
    return column.replace("_", " ").title()


def _money_axis_label(column):
    """Axis label with a " (BRL)" suffix only when the column name looks monetary."""
    is_money = any(hint in column.lower() for hint in _MONEY_COLUMN_HINTS)
    return _axis_label(column) + (" (BRL)" if is_money else "")


def generate_plot(df, user_query):
    """Pick a Plotly chart that fits the shape of a query result.

    Chart choice, first match wins:
      1. a column whose name contains "year" plus another numeric column -> bar per year;
      2. a non-numeric date/time/month column plus a numeric value column -> line chart;
      3. two or more numeric columns -> scatter with an OLS trendline;
      4. a categorical column plus one numeric column -> bar of the first 15 rows;
      5. a single numeric column -> histogram;
      6. otherwise -> bar of the first two columns (first 15 rows).

    Args:
        df: The query result.
        user_query: The question; its first 75 characters become the chart title.

    Returns:
        A Plotly figure, or ``None`` when ``df`` is empty.
    """
    if df is None or df.empty:
        return None

    title = f"📈 {user_query[:75]}"
    cols = df.columns.tolist()
    num_cols = df.select_dtypes(include='number').columns.tolist()

    # 1. Per-year bars. A numeric year is itself in num_cols, so the value column must be
    #    a different one; otherwise the year would be plotted against itself.
    year_cols = [c for c in cols if 'year' in c.lower()]
    if year_cols:
        year_col = year_cols[0]
        year_value_cols = [c for c in num_cols if c != year_col]
        if year_value_cols:
            val_col = year_value_cols[0]
            df_plot = df.copy()
            df_plot[year_col] = df_plot[year_col].astype(str)
            fig = px.bar(df_plot, x=year_col, y=val_col, title=title, text=val_col,
                         color=year_col, color_discrete_sequence=px.colors.qualitative.Bold)
            fig.update_traces(texttemplate='%{text:,.0f}', textposition='outside')
            fig.update_layout(
                height=620,
                template="plotly_white",
                xaxis_title="Year",
                yaxis_title=_money_axis_label(val_col),
            )
            return fig

    # 2. Time series. Numeric, boolean and timedelta columns are excluded. Names like
    #    `avg_delivery_time` hold durations, and pd.to_datetime would read them as epoch nanoseconds.
    date_cols = [c for c, dtype in df.dtypes.items()
                 if dtype.kind not in "biufcm"
                 and any(k in c.lower() for k in ('date', 'time', 'month'))]
    value_cols = [c for c in num_cols if c not in year_cols]
    if date_cols and value_cols:
        date_col, val_col = date_cols[0], value_cols[0]
        df_plot = df.copy()
        df_plot[date_col] = pd.to_datetime(df_plot[date_col], errors='coerce')
        # Fall through to the generic charts if nothing parsed as a date.
        if df_plot[date_col].notna().any():
            df_plot['year'] = df_plot[date_col].dt.year.astype(str)
            fig = px.line(df_plot, x=date_col, y=val_col, color='year',
                          color_discrete_sequence=px.colors.qualitative.Bold,
                          title=title, markers=True)
            fig.update_layout(
                height=620,
                template="plotly_white",
                xaxis_title="Date",
                yaxis_title=_axis_label(val_col),
            )
            return fig

    # 3-6. Generic fallbacks.
    if len(num_cols) >= 2:
        fig = px.scatter(df, x=num_cols[0], y=num_cols[1], trendline="ols", title=title)
        fig.update_layout(xaxis_title=_axis_label(num_cols[0]),
                          yaxis_title=_axis_label(num_cols[1]))
        return fig
    category_cols = [c for c in cols if c not in num_cols]
    if num_cols and category_cols:
        x_col, y_col = category_cols[0], num_cols[0]
        fig = px.bar(df.head(15), x=x_col, y=y_col, title=title)
        fig.update_layout(xaxis_title=_axis_label(x_col), yaxis_title=_axis_label(y_col))
        return fig
    if num_cols:
        return px.histogram(df, x=num_cols[0], title=title)
    return px.bar(df.head(15), x=cols[0], y=cols[1] if len(cols) > 1 else None, title=title)

In [None]:
# Conversation state shared by the Gradio callbacks.
llm_chat_history = []  # {"query": ..., "sql": ...} turns replayed to the LLM on follow-ups
last_df = last_date_col = last_num_col = None  # latest result and the columns a forecast would use

def _time_series_columns(df):
    """Return ``(date_col, value_col)`` usable for a Prophet forecast, or ``(None, None)``.

    Only non-numeric, non-boolean, non-timedelta columns whose names mention
    date/month/year/time count as dates. Numeric ones such as ``avg_delivery_time``
    are durations or bare numbers. The value is the first numeric column.
    """
    date_cols = [c for c, dtype in df.dtypes.items()
                 if dtype.kind not in "biufcm"
                 and any(k in c.lower() for k in ('date', 'month', 'year', 'time'))]
    num_cols = df.select_dtypes(include='number').columns.tolist()
    if date_cols and num_cols:
        return date_cols[0], num_cols[0]
    return None, None


def respond(message, history):
    """Gradio handler: question -> retrieved schema -> SQL -> result, explanation and chart.

    Args:
        message: The user's question from the textbox.
        history: Chatbot history as ``[user, bot]`` pairs (may be ``None``).

    Returns:
        ``(history, figure)`` with the new exchange appended. ``figure`` is ``None``
        when the request failed. Blank messages return the history unchanged.

    Side effects:
        Sets the module-level ``last_df``/``last_date_col``/``last_num_col`` read by the
        Prophet forecast, and appends successful turns to ``llm_chat_history``.
    """
    global last_df, last_date_col, last_num_col, llm_chat_history
    history = history or []
    if not message or not message.strip():
        return history, None
    history = history + [[message, "🤔 Generating SQL..."]]

    try:
        query_embedding = co.embed(model="embed-v4.0", texts=[message], input_type="search_query",
                                   embedding_types=["float"]).embeddings.float[0]
        q_emb = np.array([query_embedding]).astype('float32')
        # Never request more neighbours than there are documents: FAISS pads missing hits
        # with -1, and documents[-1] would silently repeat the last table.
        _, neighbour_ids = index.search(q_emb, k=min(6, len(documents)))
        schema_context = "\n---\n".join(documents[i] for i in neighbour_ids[0] if i >= 0)

        sql = generate_sql(message, schema_context, llm_chat_history)
        df = run_query(sql)
        last_df = df.copy()
        # Reset on every query so a forecast never uses column names from an earlier result.
        last_date_col, last_num_col = _time_series_columns(df)

        explanation, _ = explain_results(df, message, sql)
        fig = generate_plot(df, message)

        llm_chat_history.append({"query": message, "sql": sql})

        history[-1][1] = "**Generated SQL:**\n```sql\n" + sql + "\n```\n\n" + explanation
        return history, fig

    except Exception as e:
        # Show the actual error: failures here include API, SQL and safety-check errors,
        # not only questions unrelated to the database.
        history[-1][1] = (
            "Sorry, I couldn't answer that question from this database.\n\n"
            f"Error: {e}\n\n"
            "Try asking about:\n• Total revenue\n• Customer locations by state\n"
            "• Monthly order trends\n• Top products\n• Delivery time"
        )
        return history, None


def run_prophet_forecast():
    """Fit Prophet on the most recent time-series result and forecast 180 periods ahead.

    Reads the module-level ``last_df``, ``last_date_col`` and ``last_num_col`` set by
    ``respond``.

    Returns:
        ``(status_message, figure)``. ``figure`` is ``None`` when there is nothing
        suitable to forecast or fitting fails.
    """
    if last_df is None or last_date_col is None or last_num_col is None:
        return "Ask a time-series question first", None
    # The globals may describe a different result than last_df; never index blindly.
    if last_date_col == last_num_col or not {last_date_col, last_num_col} <= set(last_df.columns):
        return "Ask a time-series question first", None

    df_p = pd.DataFrame({
        'ds': pd.to_datetime(last_df[last_date_col], errors='coerce'),
        'y': pd.to_numeric(last_df[last_num_col], errors='coerce'),
    }).dropna().reset_index(drop=True)
    # Prophet cannot fit fewer than two usable rows.
    if len(df_p) < 2:
        return "Need at least two dated rows with numeric values to forecast.", None

    try:
        m = Prophet(yearly_seasonality=True, weekly_seasonality=False)
        m.fit(df_p)
        future = m.make_future_dataframe(periods=180)
        forecast = m.predict(future)
    except Exception as e:
        return f"Forecast failed: {e}", None
    fig = plot_plotly(m, forecast)
    fig.update_layout(title="🔮 Prophet 6-Month Forecast", height=680, template="plotly_white")
    return "✅ Prophet forecast generated!", fig


def _forecast_into_chat(history):
    """Run the Prophet forecast and post its status message as a new chat turn.

    run_prophet_forecast() returns a plain status string, but the Chatbot component
    holds a list of [user, bot] pairs. The status is therefore appended to the
    existing history rather than written over it.
    """
    status, fig = run_prophet_forecast()
    return (history or []) + [["🔮 Prophet Forecast", status]], fig


def _clear_chat():
    """Clear the visible chat and chart, plus the context respond() replays to the LLM."""
    global last_df, last_date_col, last_num_col
    llm_chat_history.clear()
    last_df = last_date_col = last_num_col = None
    return [], None


with gr.Blocks(title="Olist Data Dictionary Agent", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 Olist Data Dictionary Agent\n**Ask any business question → Safe SQL → Clean Charts**")

    chatbot = gr.Chatbot(height=620, show_label=False)
    msg = gr.Textbox(placeholder="Ask any business question...", label="Your Question")

    with gr.Row():
        send_btn = gr.Button("Send", variant="primary")
        forecast_btn = gr.Button("🔮 Prophet Forecast", variant="stop")
        clear_btn = gr.Button("Clear Chat", variant="secondary")

    plot_output = gr.Plot(label="📊 Visualization")

    gr.Markdown("### Quick Questions")
    with gr.Row():
        for q in ["Where are our customers located?", "What is our total revenue?", "Show monthly order trends", "Top 10 best-selling products?", "Average delivery time by state?"]:
            # Bind q as a default argument so each button fills in its own question;
            # a plain closure would see only the loop's final value.
            gr.Button(q, size="sm").click(fn=lambda x=q: x, outputs=msg)

    send_btn.click(respond, [msg, chatbot], [chatbot, plot_output])
    msg.submit(respond, [msg, chatbot], [chatbot, plot_output])
    forecast_btn.click(_forecast_into_chat, inputs=chatbot, outputs=[chatbot, plot_output])
    clear_btn.click(_clear_chat, outputs=[chatbot, plot_output])

# share=True publishes a temporary public *.gradio.live URL. Anyone holding it can run
# queries against this notebook's Cohere key. debug=True keeps the cell running to show logs.
demo.launch(share=True, debug=True)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://2ffd40c01c52125656.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
