In [None]:
import json
import sqlite3
import re
import os
from sqlalchemy import create_engine, inspect
from huggingface_hub import InferenceClient
from google.colab import userdata

# Fetch the value stored under 'HF_TOKEN' via google.colab's userdata helper.
my_token = userdata.get('HF_TOKEN')
# Module-level inference client for the Qwen/Qwen3-8B model, authenticated with
# the token above; process_database() calls client.chat_completion() on it.
client = InferenceClient(model="Qwen/Qwen3-8B", token=my_token) 

def clean_sql_query(text):
    """Extract a bare SQL statement from a raw model response.

    Extraction is attempted in this order:

    1. The contents of the first Markdown code fence (``` or ```sql).
    2. Otherwise, everything from the earliest SQL keyword (SELECT, WITH,
       VALUES, INSERT, UPDATE or DELETE, matched case-insensitively as a
       whole word followed by whitespace) up to and including the last
       semicolon, when one is present.
    3. Otherwise, the input with surrounding whitespace stripped.

    Args:
        text: Raw model output; may be empty or ``None``.

    Returns:
        The extracted SQL string, or ``""`` when ``text`` is falsy.
    """
    if not text:
        return ""

    fenced = re.search(r"```(?:sql)?\s*(.*?)\s*```", text, re.DOTALL | re.IGNORECASE)
    if fenced:
        return fenced.group(1).strip()

    # Search the original string case-insensitively rather than an upper-cased
    # copy: str.upper() can change the length of the text (e.g. "ß" -> "SS"),
    # which would make an index found in the copy point at the wrong place in
    # the original. `\s` also accepts a newline or tab after the keyword, and
    # `\b` keeps the keyword from matching in the middle of another word.
    keyword = re.search(
        r"\b(?:SELECT|WITH|VALUES|INSERT|UPDATE|DELETE)\s", text, re.IGNORECASE
    )
    if keyword is None:
        return text.strip()

    raw_sql = text[keyword.start():]
    # Drop any trailing commentary after the final statement terminator.
    last_semicolon = raw_sql.rfind(";")
    if last_semicolon != -1:
        raw_sql = raw_sql[:last_semicolon + 1]
    return raw_sql.strip()

def get_schema_string(inspector):
    """Render the tables and columns reported by ``inspector`` as prompt text.

    Args:
        inspector: Object exposing ``get_table_names()`` and
            ``get_columns(table)``; each column is read through its
            ``'name'`` and ``'type'`` keys.

    Returns:
        A newline-terminated string starting with ``"Database Schema:"``
        followed by one ``Table:`` line per table and one indented line per
        column. Returns ``""`` if reading the schema raises an exception.
    """
    lines = ["Database Schema:"]
    try:
        for table in inspector.get_table_names():
            lines.append(f"Table: {table}")
            for col in inspector.get_columns(table):
                lines.append(f"  - {col['name']} ({col['type']})")
    except Exception:
        # Best-effort: an unreadable schema yields an empty prompt section.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate so the run can still be stopped.
        return ""
    return "\n".join(lines) + "\n"

def run_query(db_path, query):
    """Execute ``query`` against the SQLite database at ``db_path``.

    Args:
        db_path: Path passed to ``sqlite3.connect``.
        query: SQL text to execute.

    Returns:
        A ``(values, ok)`` tuple. On success ``values`` holds the first
        column of every fetched row (any further columns are discarded) and
        ``ok`` is ``True``. If executing or fetching fails, returns
        ``([], False)``.
    """
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(query).fetchall()
    except Exception:
        # The query may be arbitrary model output, so any execution failure
        # just marks it as invalid. KeyboardInterrupt/SystemExit still propagate.
        return [], False
    finally:
        # Close on both paths; previously a failing query leaked the connection.
        conn.close()

    # Flatten each row to its first column.
    return [row[0] for row in rows], True

def compute_execution_accuracy(gt_results, predict_results):
    """Return the fraction of queries whose predicted results match the gold ones.

    Two result lists match when they are equal as sets, so row order and
    duplicate values are ignored.

    Args:
        gt_results: Sequence of dicts, each with a ``'results'`` iterable of
            hashable values for the ground-truth query.
        predict_results: Sequence aligned index-by-index with ``gt_results``,
            in the same format, for the predicted query.

    Returns:
        Accuracy as a float in ``[0, 1]``; ``0.0`` when ``gt_results`` is
        empty (instead of dividing by zero).

    Raises:
        IndexError: If ``predict_results`` is shorter than ``gt_results``.
    """
    num_queries = len(gt_results)
    if num_queries == 0:
        return 0.0

    num_correct = sum(
        set(gt["results"]) == set(predict_results[i]["results"])
        for i, gt in enumerate(gt_results)
    )
    return num_correct / num_queries

def process_database(db_name):
    """Run the baseline Text-to-SQL evaluation for a single database.

    Loads the questions from ``dataset/<db_name>/<db_name>.json``, looks up
    each question's gold SQL in ``golds.json`` (if that file exists), asks the
    module-level ``client`` for a predicted query, executes gold and predicted
    queries against ``dataset/<db_name>/<db_name>.sqlite`` and records whether
    their results match.

    Side effects: prints progress to stdout and writes the per-database traces
    to ``traces_baseline/<db_name>.json``.

    Args:
        db_name: Name of the database directory under ``dataset/``.

    Returns:
        A list with one trace dict per question (keys ``question_id``,
        ``difficulty``, ``pred_query``, ``target_query``,
        ``execution_accuracy``), or ``[]`` when the question file or the
        SQLite file is missing.
    """
    print(f"\n{'='*40}\nProcessing Database: {db_name}\n{'='*40}")
    path_json = f"dataset/{db_name}/{db_name}.json"
    path_sql = f"dataset/{db_name}/{db_name}.sqlite"
    if not os.path.exists(path_json) or not os.path.exists(path_sql):
        print(f"Skipping {db_name}: Missing files.")
        return []

    # Explicit UTF-8 so reading/writing JSON does not depend on the platform's
    # default locale encoding.
    with open(path_json, "r", encoding="utf-8") as f:
        questions = json.load(f)

    # Gold entries indexed by question id; stays empty when golds.json is absent.
    q_ids = {}
    if os.path.exists("golds.json"):
        with open("golds.json", "r", encoding="utf-8") as v:
            golds_list = json.load(v)
            q_ids = {g["question_id"]: g for g in golds_list}

    # The engine is only needed to introspect the schema; dispose of it right
    # away so its pooled connections are released.
    engine = create_engine(f"sqlite:///{path_sql}")
    try:
        schema_text = get_schema_string(inspect(engine))
    finally:
        engine.dispose()

    # Identical for every question, so build it once outside the loop.
    system_prompt = """You are an expert Data Scientist specialized in Text-to-SQL tasks.
You will be given a task to solve as best you can.
Task: Convert the user's question into a valid SQL query based on the schema.
Instructions:
1. Output ONLY the SQL query.
2. Do not explain your answer.
"""

    traces = []
    for i, q in enumerate(questions):
        q_id = q.get("question_id")
        # NOTE(review): the question text is read from the plural "questions"
        # key - confirm this matches the dataset's schema.
        question_text = q.get("questions", "")
        evidence = q.get("evidence", "")
        difficulty = q.get("difficulty", "unknown")
        gt_entry = q_ids.get(q_id, {})
        gt_query = gt_entry.get("target_sql", "")
        if gt_query is None:
            gt_query = ""

        print(f"--- Q{i+1} ({db_name}) ---")

        user_input = f"{schema_text}\n\nQuestion: {evidence} {question_text}"
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_input},
        ]

        api_failed = False
        raw_output = ""
        try:
            response = client.chat_completion(messages=messages, max_tokens=4096, temperature=0.01)
            content = response.choices[0].message.content
            raw_output = content if content else ""
        except Exception as e:
            # Keep going: a failed request is recorded in the trace and scored 0.
            raw_output = f"ERROR_API_FAILED: {str(e)}"
            print(f"Failed to get response for Q{i+1}")
            api_failed = True

        pred_query = raw_output if api_failed else clean_sql_query(raw_output)

        exec_acc = 0
        if not api_failed and gt_query and not pred_query.startswith("ERROR"):
            rows_gt, gt_ok = run_query(path_sql, gt_query)
            rows_pred, is_valid_sql = run_query(path_sql, pred_query)
            # Score only when BOTH queries ran. A failing gold query returns an
            # empty result list, which would otherwise compare equal to any
            # prediction that legitimately returns no rows.
            if gt_ok and is_valid_sql:
                gt_res = [{"results": rows_gt}]
                pred_res = [{"results": rows_pred}]
                exec_acc = compute_execution_accuracy(gt_res, pred_res)

        traces.append({
            "question_id": q_id,
            "difficulty": difficulty,
            "pred_query": pred_query,
            "target_query": gt_query,
            "execution_accuracy": int(exec_acc),
        })

        display = pred_query[:60] + "..." if len(pred_query) > 60 else pred_query
        print(f"SQL: {display}\nAcc: {exec_acc}")

    output_dir = "traces_baseline"
    os.makedirs(output_dir, exist_ok=True)
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened as UTF-8 explicitly.
    with open(f"{output_dir}/{db_name}.json", "w", encoding="utf-8") as f:
        json.dump(traces, f, indent=2, ensure_ascii=False)

    return traces


if __name__ == "__main__":
    # Script entry point: evaluate every database directory under `dataset/`
    # and write the combined traces to traces_baseline/traces_baseline.json.
    dataset_root = "dataset"
    output_dir = "traces_baseline"
    master_traces = []

    if os.path.exists(dataset_root):
        # os.listdir() returns entries in arbitrary order; sort so databases
        # are processed (and traces are written) in a reproducible order.
        all_dbs = sorted(
            d for d in os.listdir(dataset_root)
            if os.path.isdir(os.path.join(dataset_root, d))
        )
        for db in all_dbs:
            try:
                master_traces.extend(process_database(db))
            except Exception as e:
                # One broken database must not abort the whole run.
                print(f"Error processing {db}: {e}")

        # process_database() only creates this directory after it finishes a
        # database successfully, so it may not exist yet (e.g. every database
        # was skipped or failed).
        os.makedirs(output_dir, exist_ok=True)
        with open(f"{output_dir}/traces_baseline.json", "w", encoding="utf-8") as f:
            json.dump(master_traces, f, indent=2, ensure_ascii=False)
        print(f"Completed. Saved {len(master_traces)} traces.")
    else:
        print("Dataset folder not found.")
