In [1]:
! pip install -q langgraph  pysqlite3-binary openai matplotlib
import os
os.environ["OPENAI_API_KEY"] = "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"

In [2]:
import os
import re
import sqlite3
import time
from contextlib import closing
from typing import Optional, TypedDict, Dict, Any

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from google.colab import files
from openai import OpenAI
from langgraph.graph import StateGraph, END

In [3]:
# ----------------------------------------------------
# Upload CSV & Prep data
# ----------------------------------------------------
# Two upload attempts: the first failure only prompts again, the second aborts.
# An empty upload also counts as a failure because indexing [0] raises.
for attempt in (1, 2):
  try:
    uploaded = files.upload()
    csv_path = list(uploaded.keys())[0]
    break
  except Exception:
    if attempt == 2:
      raise FileNotFoundError('No files uploaded after second attempt.')
    print('Please upload your data file.')

Saving synthetic_bank_customer_data(in).csv to synthetic_bank_customer_data(in) (3).csv


In [4]:
# ----------------------------------------------------
# Load Data and Perform Basic Transformations
# ----------------------------------------------------
df = pd.read_csv(csv_path)

# Clean & transform data: trim surrounding whitespace on the text columns
# that are present in the file.
for col in ['Name', 'Surname', 'Gender', 'Region', 'Job Classification']:
  if col in df.columns:
    df[col] = df[col].astype(str).str.strip()

# Parse day.MonthAbbrev.2-digit-year dates and derive a month key from them.
df['Date Joined'] = pd.to_datetime(df['Date Joined'], format = '%d.%b.%y')
df['Year Month'] = df['Date Joined'].dt.to_period('M').astype(str) # YYYY-MM

# Pin the dtypes of the key numeric columns when they exist.
for c, t in [('Customer ID', 'int64'), ('Age', 'int64'), ('Balance', 'float64')]:
  if c in df.columns:
    df[c] = df[c].astype(t)

# Right-inclusive bins whose upper edges line up with the labels:
# (55, 65] is '56-65' and everything above 65 is '65+' (open-ended, so very
# large ages are not silently turned into NaN).
df['Age Group'] = pd.cut(df['Age'], bins = [0, 25, 35, 45, 55, 65, np.inf],
                         labels = ['18-25', '26-35', '36-45', '46-55', '56-65', '65+'],
                         right = True)
def check_yoy_availability(_df: pd.DataFrame) -> bool:
  """Return True when `_df['Date Joined']` covers more than one calendar year.

  A frame without a 'Date Joined' column yields False. Missing dates are
  ignored when counting distinct years.
  """
  if 'Date Joined' in _df.columns:
    joined = pd.to_datetime(_df['Date Joined'])
    distinct_years = joined.dt.year.dropna().unique()
    return len(distinct_years) > 1
  return False

In [5]:
# df.head(2)


In [6]:
# ----------------------------------------------------
# Save CSV to SQLite DB
# ----------------------------------------------------
# Saving to SQLite
db_path = 'bank_customer.db'
# sqlite3's own context manager only commits/rolls back the transaction; it
# does not close the connection. closing() releases the handle, while the
# inner `conn` context keeps the commit-on-success behaviour.
with closing(sqlite3.connect(db_path)) as conn, conn:
  df.to_sql('customers', conn, if_exists = 'replace', index = False)

# SQL Helper
def run_sql_query(query: str, db_path: Optional[str] = db_path) -> pd.DataFrame:
  """Run `query` against the SQLite database at `db_path` and return the rows.

  Args:
    query: SQL text passed unchanged to `pd.read_sql_query`.
    db_path: Path of the SQLite file; defaults to the notebook's database.

  Returns:
    A DataFrame with the query result.
  """
  # `with sqlite3.connect(...)` alone would leave the connection open after
  # every call; closing() guarantees the handle is released.
  with closing(sqlite3.connect(db_path)) as conn:
    return pd.read_sql_query(query, conn)

In [7]:
# ----------------------------------------------------
# Graph State & Helpers
# ----------------------------------------------------
class QAState(TypedDict, total = False):
  """State dictionary passed between the graph nodes.

  Declared with total = False, so every key is optional; nodes read keys
  with .get() and fill them in as the pipeline progresses.
  """
  # Question text as supplied by the caller of the graph.
  question: str
  # Question after `clarify`, possibly with a month-over-month assumption appended.
  resolved_question: Optional[str]
  # Human-readable note about an assumption made by `clarify` (or None).
  assumption: Optional[str]
  # Column listing built by `discover_schema` and embedded in the LLM prompts.
  schema_hint: Optional[str]
  # SQL text produced by `llm_to_sql` / `repair_sql` and adjusted by `validate_sql`.
  sql: Optional[str]
  # Validation failure message set by `validate_sql`; None when the SQL passed.
  error: Optional[str]
  # Intent label assigned in `clarify`.
  intent: Optional[str]
  # Query result stored by `execute_sql`.
  result: Optional[pd.DataFrame]
  # Frame that `visualize` actually plotted.
  df_for_plot: Optional[pd.DataFrame]
  # Chart description that `explain` appends to the answer.
  viz_note: Optional[str]
  # Final text answer (set by `explain`, or early by `clarify`).
  answer: Optional[str]
  # Execution metadata written by `execute_sql` (question, rows, cols, duration_sec).
  log: Dict[str, Any]

In [8]:
# ----------------------------------------------------
# Setting up OpenAI Client, Configs & Intent
# ----------------------------------------------------

# Shared OpenAI client used by the SQL-generation and SQL-repair nodes.
# NOTE(review): constructed without arguments, so it presumably reads the
# OPENAI_API_KEY environment variable set in the setup cell — confirm.
client = OpenAI()

# Column names of the `customers` table; filled in by `discover_schema`.
ALLOWED_COLS = []
REQUIRED_BRACKETS = {
    "Customer ID",
    "Job Classification",
    "Date Joined",
    "Year Month",
    "Age Group",
}
# Keyword sets used for substring-based intent detection.
GROWTH_MOM_KEYS = {"mom", "month over month", "month-on-month", "m/m"}
GROWTH_YOY_KEYS = {"yoy", "year over year", "year-over-year", "y/y"}
PREDICT_KEYS = {"predict", "forecast", "trend", "future", "project", "projection"}

# Detect Intention of Question
def detect_intent(ques: str) -> str:
  """Classify a question as 'predict', 'growth_yoy', 'growth_mom' or 'generic'.

  Matching is a case-insensitive substring search, checked in that priority
  order; the bare word 'growth' falls back to month-over-month.
  """
  text = ques.lower() if ques else ""

  def mentions(keywords) -> bool:
    return any(keyword in text for keyword in keywords)

  if mentions(PREDICT_KEYS):
    return 'predict'
  if mentions(GROWTH_YOY_KEYS):
    return 'growth_yoy'
  if mentions(GROWTH_MOM_KEYS) or 'growth' in text:
    return 'growth_mom'
  return 'generic'



In [9]:
# ----------------------------------------------------
# Nodes
# ----------------------------------------------------

# 1. Schema auto-discovery for validating schema_hint from DB
def discover_schema(state: QAState) -> QAState:
  """Read the `customers` column names from SQLite and build the prompt hint.

  Side effects: replaces the module-level ALLOWED_COLS with the discovered
  column names and stores the hint text in state['schema_hint'].
  """
  global ALLOWED_COLS
  pragma = run_sql_query("PRAGMA table_info(customers);")
  cols = pragma['name'].tolist()
  ALLOWED_COLS = cols
  # Building Schema hint
  lines = ["TABLE customers columns (use EXACT names, bracket spaces):"]
  for c in cols:
    # Names containing a space, hyphen or underscore are shown in [brackets].
    needs_brackets = any(ch in c for ch in (" ", "-", "_"))
    shown = f"[{c}]" if needs_brackets else c
    # The closing parenthesis keeps each hint line well-formed for the model.
    lines.append(f"- {shown} (type inferred)")
  state['schema_hint'] = "\n".join(lines)
  return state


# 2. Clarifying on growth questions
def clarify(state: QAState) -> QAState:
  """Detect the question's intent and resolve ambiguous growth questions.

  - Year-over-year questions are answered immediately (no SQL) when the data
    covers at most one calendar year.
  - Month-over-month questions without an explicit timeframe get an
    assumption note and an expanded `resolved_question`.
  """
  q = state['question']
  intent = detect_intent(q)
  # -- Check YoY availability
  years = []
  if 'Date Joined' in df.columns:
    # int() keeps the years printable as "2015" even when missing dates make
    # pandas return the year column as floats.
    years = sorted(int(y) for y in pd.to_datetime(df['Date Joined']).dt.year.dropna().unique())
  if intent == 'growth_yoy' and len(years) <= 1:
    state['intent'] = 'growth_yoy_unavailable'
    # Adjacent literals are concatenated verbatim, so the separating spaces
    # must be part of the text.
    state['answer'] = (
        "Year-over-Year growth unavailable. Only "
        f"{', '.join(map(str, years)) or 'no'} year(s) of data found."
    )
    return state
  # -- Assuming MoM only if question without clear timeframe.
  need_assumption = (intent == "growth_mom" and not any(k in q.lower() for k in ['yoy', 'year', 'month', '201', '202']))
  state['assumption'] = "Assumed month-over-month over the available period." if need_assumption else None
  state['resolved_question'] = q + " Assume month-over-month over the available period." if need_assumption else q
  state['intent'] = intent
  return state


# 3. LLM to SQL, using discovered schema.
def llm_to_sql(state: QAState) -> QAState:
  """Ask the chat model for a SQLite query answering the resolved question.

  Does nothing when `clarify` already produced an answer for an unavailable
  year-over-year request. Otherwise builds an intent-specific prompt, calls
  the model and stores the fence-stripped SQL in state['sql'].
  """
  if state.get('intent') == 'growth_yoy_unavailable' and state.get('answer'):
    return state
  intent = state.get('intent', 'generic')
  schema_hint = state.get('schema_hint','')
  question = state.get('resolved_question') or state['question']
  # One if/elif/else chain: with two independent `if`s the month-over-month
  # prompt would be overwritten by the generic fallback below.
  if intent == 'growth_mom':
    user_prompt = (
        "Return ONLY executable SQLite SQL. No prose/markdown. Use table `customers`.\n"
        "Use exact column names. Bracket spaces like [Customer ID].\n"
        "Compute month-over-month growth using counts of [Customer ID].\n"
        "If comparing months, derive next month with date([Year Month] || '-01', '+1 month').\n\n"
        f"Question: {question}\n\n{schema_hint}"
    )
  elif intent == 'growth_yoy':
    user_prompt = (
        "Return ONLY executable SQLite SQL. No prose/markdown. Use table `customers`.\n"
        "Use exact column names. Bracket spaces like [Customer ID].\n"
        "Compute year-over-year growth using counts of [Customer ID].\n"
        "Compare same month vs previous year via date([Year Month] || '-01', '-1 year').\n\n"
        f"Question: {question}\n\n{schema_hint}"
    )
  else:
    user_prompt = (
        "Return ONLY executable SQLite SQL. No prose/markdown. Use table `customers`.\n"
        "Use exact column names. Bracket spaces like [Customer ID].\n"
        "Prefer CTEs for clarity.\n\n"
        f"Question: {question}\n\n{schema_hint}"
    )

  response = client.chat.completions.create(
      model = 'gpt-4o',
      messages = [
          {'role': 'system', 'content': "Return ONLY SQL for SQLite over the `customers` table."},
          {'role': 'user', 'content': user_prompt}
      ],
      temperature = 0
  )
  # The content may be missing; treat that as an empty reply.
  sql_query = (response.choices[0].message.content or "").strip()
  # Remove a surrounding markdown fence, with or without a language tag.
  lowered = sql_query.lower()
  for fence in ("```sqlite", "```sql", "```"):
    if lowered.startswith(fence):
      sql_query = sql_query[len(fence):]
      break
  if sql_query.endswith("```"):
    sql_query = sql_query[:-3]
  state['sql'] = sql_query.strip()
  return state

# 4. SQL validators and repair
def validate_sql(state: QAState) -> QAState:
  """Check the generated SQL and record the outcome in state['error'].

  Rejects multi-statement SQL, statements containing data/schema-changing
  keywords, and spaced column names written squashed without brackets. When
  the SQL passes and has no LIMIT, ` LIMIT 20` is appended. state['error']
  is None on success and a message otherwise.
  """
  sql_text = (state.get('sql') or '').strip()
  # Trailing semicolons are harmless; any other ';' starts a second statement.
  body = sql_text.rstrip(';').rstrip()
  sql_lower = body.lower()
  # reject multi statement SQL
  if ';' in sql_lower:
    state['error'] = "Multiple SQL statements detected."
    return state
  # prevent table update statements; whole-word matching avoids rejecting
  # harmless identifiers that merely contain one of these words.
  if re.search(r"\b(drop|delete|update|insert|alter|pragma|attach|detach)\b", sql_lower):
    state['error'] = "Unsafe SQL operation detected."
    return state
  # enforce brackets for spaced columns
  for col in ALLOWED_COLS:
    if ' ' in col:
      squashed = col.lower().replace(' ','')
      if squashed in sql_lower and f"[{col.lower()}]" not in sql_lower:
        state['error'] = f"Column requires brackets: [{col}]"
        return state
  # limit rows for data view requests; \b also recognises a LIMIT that starts
  # a new line, so a second LIMIT is never appended.
  if re.search(r"\bselect\b", sql_lower) and not re.search(r"\blimit\b", sql_lower):
    state['sql'] = body + ' LIMIT 20'
  state['error'] = None
  return state

def repair_sql(state: QAState) -> QAState:
  """Ask the model to fix SQL that failed validation, then validate it again.

  No-op when state['error'] is empty. The repaired SQL replaces state['sql'].
  Because the graph routes this node straight to execution, the repaired SQL
  is re-checked here; if it still fails, the SQL is cleared and the failure
  is reported through state['answer'] instead of being executed.
  """
  if not state.get('error'):
    return state
  hint = state['error']
  schema_hint = state.get('schema_hint', '')
  orig = state.get('sql', '')
  resp = client.chat.completions.create(
      model = 'gpt-4o',
      messages = [
          {'role': 'system', 'content': "Return ONLY SQL for SQLite over the `customers` table."},
          {'role': 'user', 'content': f"Fix this SQL issue: {hint}. Use exact column names and brackets.\n\nOriginal SQL: \n{orig}\n\n{schema_hint}"}
      ],
      temperature = 0
  )
  # The content may be missing; treat that as an empty reply.
  sql = (resp.choices[0].message.content or "").strip()
  # Remove a surrounding markdown fence, with or without a language tag.
  lowered = sql.lower()
  for fence in ("```sqlite", "```sql", "```"):
    if lowered.startswith(fence):
      sql = sql[len(fence):]
      break
  if sql.endswith("```"):
    sql = sql[:-3]
  state['sql'] = sql.strip()
  # Never hand unvalidated model output to the executor.
  state = validate_sql(state)
  if state.get('error'):
    state['answer'] = f"Could not produce a valid SQL query: {state['error']}"
    # An empty SQL makes the execute step skip the query and lets the
    # explain step return the answer above unchanged.
    state['sql'] = ''
  return state

# 5. Execute SQL with observability
def execute_sql(state : QAState) -> QAState:
  """Run state['sql'] and store the result frame plus timing/shape metadata.

  With no SQL in the state, an empty DataFrame and a zeroed log are stored.
  """
  query = state.get('sql')
  if not query:
    # Nothing to run: keep downstream nodes working with an empty result.
    state['result'] = pd.DataFrame()
    state['log'] = {'question': state.get('question'), 'rows': 0, 'cols': 0, 'duration_sec':0}
    return state

  started = time.time()
  frame = run_sql_query(query)
  elapsed = time.time() - started

  if isinstance(frame, pd.DataFrame):
    n_rows, n_cols = int(frame.shape[0]), int(frame.shape[1])
  else:
    n_rows = n_cols = 0

  state['result'] = frame
  state['sql'] = query
  state['log'] = {
      'question': state.get('question'),
      'rows': n_rows,
      'cols': n_cols,
      'duration_sec': round(elapsed, 4)}
  return state

# 6. Visualize output (only if asked)
def visualize(state: QAState) -> QAState:
  """Draw a chart of the result when the question asks for one.

  A line chart is used when the result has a datetime and a numeric column,
  a bar chart (numeric column summed per category) when it has a categorical
  and a numeric column. The plotted frame goes to state['df_for_plot'] and a
  description (or the failure reason) to state['viz_note'].
  """
  q = (state.get('question') or '').lower()
  needs_viz = any(w in q for w in ['chart', 'plot', 'graph', 'show trend', 'visualize', 'line chart', 'bar chart', 'visualise'])
  df_res = state.get('result')
  # no viz if not requested or result empty
  if not needs_viz or not isinstance(df_res, pd.DataFrame) or df_res.empty:
    return state
  num_cols = df_res.select_dtypes(include=[np.number]).columns.tolist()
  date_cols = df_res.select_dtypes(include=["datetime64[ns]", "datetime64[ns, UTC]"]).columns.tolist()
  cat_cols = [c for c in df_res.columns if c not in (num_cols + date_cols)]

  try:
    if date_cols and num_cols:
      # line chart
      x, y = date_cols[0], num_cols[0]
      df_plot = df_res.sort_values(x)
      _render_chart(plt.plot, df_plot, x, y, f"{y} over time")
      # 'viz_note' is the key declared in QAState and read by `explain`.
      state['viz_note'] = f"Line chart of {y} by {x} displayed."
      state['df_for_plot'] = df_plot
    elif cat_cols and num_cols:
      # bar chart
      x, y = cat_cols[0], num_cols[0]
      df_plot = df_res.groupby(x, as_index = False)[y].sum()
      _render_chart(plt.bar, df_plot, x, y, f"{y} by {x}")
      state['viz_note'] = f"Bar chart of {y} by {x} displayed."
      state['df_for_plot'] = df_plot
  except Exception as exc:
    # Charting stays best-effort, but the failure is reported, not hidden.
    state['viz_note'] = f"Chart could not be rendered: {exc}"
  return state


def _render_chart(draw, df_plot, x, y, title):
  """Render one figure with `draw` (plt.plot or plt.bar) of column y against x."""
  plt.figure()
  draw(df_plot[x], df_plot[y])
  plt.title(title)
  plt.xlabel(x)
  plt.ylabel(y)
  plt.xticks(rotation = 45)
  plt.tight_layout()
  plt.show()


In [10]:
# 7. Explainable (Readable) Answer Summaries
def explain(state: QAState) -> QAState:
  """Turn the query result into a readable answer stored in state['answer'].

  An answer that already exists without any SQL is returned untouched. An
  empty result produces a "No rows matched" message. Otherwise the result is
  optionally sorted according to superlatives in the question, summarised
  according to its shape, and combined with the assumption, chart note, SQL
  text, a preview of the first rows and the execution log.
  """
  df_result = state.get('result')
  query = state.get('sql','')
  question = (state.get('question') or "").lower().strip()
  assumption = state.get('assumption')
  log = state.get('log', {})
  # if there was already a msg for a no output query, return it
  if state.get('answer') and not state.get('sql'):
    return state
  if not isinstance(df_result, pd.DataFrame) or df_result.empty:
    state['answer'] = ((f"Assumption: {assumption}\n\n" if assumption else "")+
                       "No rows matched your question.\n\n" +
                       f"SQL executed:\n{query}\n\n" +
                       f"Observability and rows: 0, duration: {log.get('duration_sec','?')}s"
                       )
    return state

  # Formatting large numbers with commas and floats with 2 decimal points.
  def fmt_num(x):
    if pd.isna(x): return "NA"
    try:
      xf = float(x)
      return f"{int(xf):,}"if xf.is_integer() else f"{float(xf):,.2f}"
    except Exception:
      # Values that cannot be converted to a number are shown as-is.
      return str(x)
  # Date, Numeric and Categorical variables
  num_cols = df_result.select_dtypes(include=[np.number]).columns.tolist()
  date_cols = df_result.select_dtypes(include=["datetime64[ns]", "datetime64[ns, UTC]"]).columns.tolist()
  cat_cols = [c for c in df_result.columns if c not in (num_cols + date_cols)]
  # Sorting result if asked
  sort_desc_words = {'largest', 'highest', 'biggest', 'most', 'top', 'max', 'maximum', "latest", "newest"}
  sort_asc_words = {'smallest', 'lowest', 'least', 'bottom', 'min', 'minimum', "earliest", "oldest"}
  metric = num_cols[0] if num_cols else None
  if metric:
    if any(word in question for word in sort_desc_words):
      df_result = df_result.sort_values(by = metric, ascending = False)
    elif any(word in question for word in sort_asc_words):
      df_result = df_result.sort_values(by = metric, ascending = True)

  # Summaries
  # Case A: Direct single value answers
  if df_result.shape == (1,1):
    val = df_result.iat[0, 0]
    summary = (f"There are {fmt_num(val)}." if any(k in question for k in ['how', 'many', 'count']) else f"The result is {fmt_num(val)}.")
  # Case B: 1 row, many columns -> key: value summary
  elif df_result.shape[0] == 1 and df_result.shape[1] > 1:
    pairs = [f"{col}: {fmt_num(df_result.iloc[0][col])}" if col in num_cols
             else f"{col}: {df_result.iloc[0][col]}"
             for col in df_result.columns]
    summary = "The result row is: " + ", ".join(pairs)
  # Case C: 2 columns -> 1 numeric, 1 categorical
  elif df_result.shape[1] == 2 and len(num_cols) == 1 and len(cat_cols) == 1:
    group_col_name, metric_col_name = cat_cols[0], num_cols[0]
    lead = "On average, " if any(word in metric_col_name.lower() for word in ["avg", "average", "mean"]) else ""
    pairs = [f"{g} have {fmt_num(v)}" for g, v in zip(df_result[group_col_name].astype(str), df_result[metric_col_name])]
    summary = lead + ", ".join(pairs) +"."
  # Case D: Date wise summary
  elif len(date_cols) >= 1 and df_result.shape[1] <= 3:
    parts = []
    for dcol in date_cols:
      try:
        dmin = pd.to_datetime(df_result[dcol].min())
        dmax = pd.to_datetime(df_result[dcol].max())
        parts.append(f"{dcol} between {dmin} and {dmax}")
      except Exception:
        parts.append(f"{dcol} (date column)")
    summary = "; ".join(parts) +"."
  # Case E: categorical column(s) plus a numeric metric -> list the first rows
  elif len(cat_cols) >= 1 and len(num_cols) >= 1:
    cats, mcol = cat_cols[:2], num_cols[0]
    items = []
    for _, row in df_result.head(5).iterrows():
      items.append(f"{' & '.join(f'{c}={row[c]}' for c in cats)}: {fmt_num(row[mcol])}")
    # Every prefix ends with a space so it never runs into the row count.
    prefix = 'Top ' if any(word in question for word in sort_desc_words) else "Lowest " if any(word in question for word in sort_asc_words) else "Sample "
    summary = f"{prefix}{min(5, len(df_result))} by {mcol}: " + ", ".join(items) + "."
  # Fallback: any other shape (e.g. several rows of numeric-only columns)
  else:
    r, c = df_result.shape
    # Must be bound to `summary`; the answer below is built from it.
    summary = f"The query returned {r} row(s) and {c} column(s)."

  preview = df_result.head().to_string(index = False)
  note = (f"\n\nAssumption: {assumption}" if assumption else "")
  viz = (f"\n\n{state.get('viz_note')}" if state.get('viz_note') else "")
  obs = f"\n\nObservability and rows: {log.get('rows','?')}, cols: {log.get('cols','?')}, duration: {log.get('duration_sec','?')}s"
  # NOTE(review): 'df_result' is not a key declared in QAState — confirm
  # whether the (possibly re-sorted) frame was meant to be stored under 'result'.
  state['df_result'] = df_result
  state['answer'] = (
      f"{summary}{note}{viz}\n\n"
      f"SQL executed:\n {query}\n\n"
      f"\nPreview of result(s):\n{preview}{obs}")

  return state



In [11]:
# ----------------------------------------------------
# Graph Wires
# ----------------------------------------------------

# Nodes (ids, functions)
workflow = StateGraph(QAState)

# Register each node function under the id that the edges refer to.
for node_id, node_fn in (
    ('discover_schema', discover_schema),
    ('clarify', clarify),
    ("generate_sql", llm_to_sql),
    ("validate", validate_sql),
    ("repair", repair_sql),
    ("execute", execute_sql),
    ("visualize", visualize),
    ("explain", explain),
):
  workflow.add_node(node_id, node_fn)

# Entry point
workflow.set_entry_point("discover_schema")

# Edges: linear path from schema discovery up to validation.
for src, dst in (
    ("discover_schema", "clarify"),
    ("clarify", "generate_sql"),
    ("generate_sql", "validate"),
):
  workflow.add_edge(src, dst)

def needs_repair(state: QAState) -> str:
  """Pick the node after validation: 'repair' on an error, else 'execute'."""
  if state.get('error'):
    return "repair"
  return "execute"

workflow.add_conditional_edges("validate", needs_repair, {"repair": "repair", "execute": "execute"})

# Remaining linear path: repair joins the main flow again at execution.
for src, dst in (
    ("repair", "execute"),
    ("execute", "visualize"),
    ("visualize", "explain"),
    ("explain", END),
):
  workflow.add_edge(src, dst)

app = workflow.compile()

In [13]:
from IPython.display import display
# ----------------------------------------------------
# Ask
# ----------------------------------------------------

# Question Function:
def ask(question: str, mode: str = 'text'):
  """Run the question through the compiled graph.

  Returns the stored result frame when `mode` is 'table'; for any other mode
  returns the text answer (an empty string when no answer was produced).
  """
  final_state = app.invoke(QAState(question = question))
  return final_state.get('result') if mode == 'table' else final_state.get('answer','')

user_question = input("Ask a question: ")
# A blank reply falls back to the text mode.
raw_mode = input("Select mode (text/table)?: ")
mode_choice = raw_mode.strip().lower() or 'text'
print("\n Running Analysis...\n")
result = ask(user_question, mode = mode_choice)

# Tables get the rich notebook rendering; everything else is printed as text.
show_result = display if mode_choice == 'table' else print
show_result(result)

Ask a question: Which regions have the highest concentration of specific job classifications?
Select mode (text/table)?: text

 Running Analysis...

Top 4 by JobCount: Region=England & Job Classification=White Collar: 1,501, Region=Scotland & Job Classification=Blue Collar: 544, Region=Wales & Job Classification=White Collar: 305, Region=Northern Ireland & Job Classification=Other: 105.

SQL executed:
 WITH JobCounts AS (
    SELECT 
        [Region], 
        [Job Classification], 
        COUNT([Customer ID]) AS JobCount
    FROM 
        customers
    GROUP BY 
        [Region], 
        [Job Classification]
),
MaxJobCounts AS (
    SELECT 
        [Region], 
        MAX(JobCount) AS MaxJobCount
    FROM 
        JobCounts
    GROUP BY 
        [Region]
)
SELECT 
    jc.[Region], 
    jc.[Job Classification], 
    jc.JobCount
FROM 
    JobCounts jc
JOIN 
    MaxJobCounts mjc
ON 
    jc.[Region] = mjc.[Region] AND 
    jc.JobCount = mjc.MaxJobCount
ORDER BY 
    jc.[Region], 
    jc.