In [None]:
import os
import random
from datetime import date, timedelta
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
from sqlalchemy import create_engine, Column, Integer, String, ForeignKey, Date, Float
import sqlite3
import re
import ast
import json
import contextlib

In [None]:


def extract_sql_code_block(text: str) -> str:
    """
    Normalise an LLM answer into a single-line SQL string or a canonical result string.

    Resolution order:
    1. A fenced ```sql ... ``` block (case-insensitive): its body with newlines turned into spaces.
    2. Text that parses (via ast.literal_eval) as a list of tuples: every value is converted with
       str() and stripped, the rows are sorted, and the tuple reprs are joined with ", ".
    3. Anything else: a surrounding [...] pair and any quotes just inside it are removed, then
       newlines become spaces.
    """
    cleaned = text.strip()

    fenced = re.search(r"```sql\s*(.*?)\s*```", cleaned, re.DOTALL | re.IGNORECASE)
    if fenced is not None:
        return fenced.group(1).replace('\n', ' ').strip()

    # literal_eval only accepts Python literals, so ordinary SQL text simply fails to parse here.
    try:
        value = ast.literal_eval(cleaned)
    except Exception:
        value = None

    if isinstance(value, list) and all(isinstance(entry, tuple) for entry in value):
        rows = sorted(tuple(str(cell).strip() for cell in entry) for entry in value)
        return ", ".join(map(str, rows))

    if cleaned.startswith("[") and cleaned.endswith("]"):
        cleaned = cleaned[1:-1].strip().strip("'\"")

    return cleaned.replace('\n', ' ').strip()

In [None]:
import getpass
import os

# Ask for the Together key interactively only when it is not already exported, so the secret
# is never typed into a cell. The prompt text names the key being requested.
if "TOGETHER_API_KEY" not in os.environ:
    os.environ["TOGETHER_API_KEY"] = getpass.getpass("Together API key: ")

In [None]:
import os
# Only fall back to the placeholder when no key was supplied. Unconditional assignment would
# clobber the real key entered at the getpass prompt in the previous cell.
os.environ.setdefault("TOGETHER_API_KEY", "yourAPIKeyHere")

In [None]:
# Report only whether a key is configured. Displaying the value would persist the secret in the saved output.
bool(os.environ.get("TOGETHER_API_KEY"))

In [None]:
from langchain_together import ChatTogether

# Chat model used by the agent cells below: Llama 3.3 70B on Together with temperature 0,
# no explicit token or timeout limits, and up to two retries per request.
llm = ChatTogether(
    model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
    max_retries=2,
    temperature=0,
    timeout=None,
    max_tokens=None,
)

In [None]:
# Quick connectivity check: one round-trip to the model, displaying only the reply text.
reply = llm.invoke("who are you?")
reply.content

In [None]:
import os
from langchain.chat_models import ChatOpenAI

# Alternative client: Together's OpenAI-compatible endpoint.
# NOTE: this rebinds `llm`, so every later cell uses this model instead of the ChatTogether one above.
# The key is read from the environment. A hardcoded placeholder string cannot authenticate and
# would leak a real key if one were pasted in.
# NOTE(review): langchain.chat_models.ChatOpenAI is deprecated in favour of langchain_openai.ChatOpenAI.
llm = ChatOpenAI(
    model_name="togethercomputer/llama-3-70b",
    openai_api_base="https://api.together.xyz/v1",
    openai_api_key=os.environ["TOGETHER_API_KEY"]
)

In [None]:
# Groq API keys must never be stored in this notebook. Keys previously pasted here should be
# treated as leaked and revoked. Provide GROQ_API_KEY via the environment (or the prompt in the
# next cell) instead.

In [None]:
import getpass
import os

# Read the Groq key from the environment, or prompt for it. Never hardcode it in the notebook.
if "GROQ_API_KEY" not in os.environ:
    os.environ["GROQ_API_KEY"] = getpass.getpass("Groq API key: ")

In [None]:
from langchain_groq import ChatGroq

# Optional: switch the agent to Groq. This needs GROQ_API_KEY in the environment (see the cell above).
# In Colab the key can come from `userdata.get('GROQ_API_KEY')`. Never paste the key itself here.

#llm = ChatGroq(model='llama-3.3-70b-versatile')

#print(llm.invoke('who are you?').content)

In [None]:
import os
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit

# Root of the Spider databases. Set SPIDER_DB_DIR to point elsewhere instead of editing the path.
spider_db_dir = os.environ.get("SPIDER_DB_DIR", "/Users/onurcanmemis/Downloads/spider_data/database")

# Single-database playground on concert_singer, used by the exploratory agent cells below.
db = SQLDatabase.from_uri(f"sqlite:///{spider_db_dir}/concert_singer/concert_singer.sqlite")
print(db.dialect)

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

tools = toolkit.get_tools()

tools

In [None]:
# Show each tool the toolkit exposes, together with the description the agent is given.
for tool in tools:
    print(f"{tool.name}\n{tool.description}")

In [None]:
from langchain_core.messages import HumanMessage
from langgraph.prebuilt import create_react_agent

# System prompt template for the ReAct SQL agent. {dialect} and {table_names} are filled in below.
agent_prompt_template = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer.
You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database.

The tables in the database is given as {table_names}. To start you should ALWAYS look at the tables and their columns in the database to see what you
can query. Do NOT skip this step.

Then you should query the schema of the most relevant tables.
"""

table_names = db.get_usable_table_names()
prompt_fields = {"dialect": "SQLite", "table_names": ", ".join(table_names)}
system_message = agent_prompt_template.format(**prompt_fields)

agent_executor = create_react_agent(llm, tools, prompt=system_message)

In [None]:
import ast

def extract_sql_and_result_from_agent_output(agent_output: dict) -> dict:
    """
    Extract the last SQL query the agent ran and the result it got back.

    Walks ``agent_output["messages"]`` in order:
    - messages whose ``name`` is ``"sql_db_query"`` carry the query output. When that output
      parses (ast.literal_eval) as a list of tuples, every value is stringified and stripped,
      the rows are sorted, and the tuple reprs are joined with ", ". Otherwise (error text,
      non-tabular output) the raw text is kept.
    - other messages may carry ``tool_calls``. The ``args["query"]`` of each ``sql_db_query``
      call is remembered.
    Later messages overwrite earlier ones, so the final query and its final result win.

    Args:
        agent_output: dict returned by ``agent_executor.invoke``. Only its "messages" entry is read.

    Returns:
        dict with:
            - 'generated_sql': last SQL query string ("" if none was issued)
            - 'generated_result': last query result, formatted as above ("" if none)
    """
    messages = agent_output.get("messages", [])

    sql_query = ""
    sql_result = ""

    for msg in messages:
        if getattr(msg, "name", None) == "sql_db_query":
            # Tool message holding the output of the query that was just executed.
            raw_result = msg.content
            if not isinstance(raw_result, str):
                raw_result = str(raw_result)
            try:
                parsed = ast.literal_eval(raw_result)
            except Exception:
                parsed = None  # error messages and other non-literal text
            if isinstance(parsed, list) and all(isinstance(row, tuple) for row in parsed):
                sorted_rows = sorted(tuple(str(value).strip() for value in row) for row in parsed)
                sql_result = ", ".join(str(row) for row in sorted_rows)
            else:
                # Always replace the previous value. Keeping it would pair this query with a
                # stale result from an earlier one.
                sql_result = raw_result
        else:
            # AI message: remember the SQL of each sql_db_query tool call it requested.
            for tool_call in getattr(msg, "tool_calls", None) or []:
                if tool_call.get("name") == "sql_db_query":
                    sql_query = (tool_call.get("args") or {}).get("query", sql_query)

    return {
        "generated_sql": sql_query.strip(),
        "generated_result": sql_result.strip()
    }

In [None]:
import json

# golden_data_easy is built in the "Easy Questions" section further down. Load it here when this
# cell runs first (e.g. Restart & Run All) instead of failing with NameError.
if "golden_data_easy" not in globals():
    with open("extracted_sql_examples_easy.json", "r") as f:
        golden_data_easy = [
            {"db_id": row["db_id"], "question": row["question"], "query": row["gold_sql"], "result": row["gold_result"]}
            for row in json.load(f)
        ]

# NOTE(review): agent_executor is bound to the concert_singer database above. Confirm this example's
# db_id matches, otherwise the agent cannot answer it.
answer = agent_executor.invoke(
    {"messages": [{"role": "user", "content": golden_data_easy[1]["question"]}]},
    config={"recursion_limit": 10}
)

In [None]:
# Show the final SQL and the normalised result the agent produced for the single example above.
extract_sql_and_result_from_agent_output(agent_output=answer)

In [None]:
import json

# Same guard as the invoke cell: the easy examples are only loaded further down the notebook.
if "golden_data_easy" not in globals():
    with open("extracted_sql_examples_easy.json", "r") as f:
        golden_data_easy = [
            {"db_id": row["db_id"], "question": row["question"], "query": row["gold_sql"], "result": row["gold_result"]}
            for row in json.load(f)
        ]

# Stream the full state after every graph step and pretty-print the newest message, to trace the agent.
events = agent_executor.stream(
    {"messages": [("user", golden_data_easy[0]["question"])]},
    stream_mode="values",
)
for event in events:
    event["messages"][-1].pretty_print()

## Hard Questions

In [None]:
import json

# Load the hard gold examples and rename gold_sql / gold_result to the short keys used below.
with open("extracted_sql_examples_hard.json", "r") as f:
    data = json.load(f)

golden_data_hard = []
for record in data:
    golden_data_hard.append({
        "db_id": record["db_id"],
        "question": record["question"],
        "query": record["gold_sql"],
        "result": record["gold_result"],
    })

In [None]:
# Preview the first hard example (keys: db_id, question, query, result).
golden_data_hard[0]

In [None]:
import json
import os
from tqdm import tqdm
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langgraph.prebuilt import create_react_agent

# Run the ReAct SQL agent over every hard question. Progress is checkpointed to save_path after each
# question, so an interrupted run resumes where it stopped.
save_path = "hard_answers.json"
# Root of the Spider databases. Set SPIDER_DB_DIR to point elsewhere.
spider_db_dir = os.environ.get("SPIDER_DB_DIR", "/Users/onurcanmemis/Downloads/spider_data/database")

# Load previous progress if available
if os.path.exists(save_path):
    with open(save_path, "r") as f:
        hard_answers = json.load(f)
else:
    hard_answers = []

start_index = len(hard_answers)

# Wrap tqdm around the remaining examples
for i in tqdm(range(start_index, len(golden_data_hard)), desc="Generating SQL", unit="q"):
    item = golden_data_hard[i]
    try:
        print(item["question"])

        # Setup DB first (important: get tables *after* db is loaded)
        db = SQLDatabase.from_uri(f"sqlite:///{spider_db_dir}/{item['db_id']}/{item['db_id']}.sqlite")
        table_names = db.get_usable_table_names()

        system_message = """
        You are an agent designed to interact with a SQL database.
        Given an input question, create a syntactically correct {dialect} query to run,
        then look at the results of the query and return the answer.
        You can order the results by a relevant column to return the most interesting
        examples in the database. Never query for all the columns from a specific table,
        only ask for the relevant columns given the question.

        You MUST double check your query before executing it. If you get an error while
        executing a query, rewrite the query and try again.

        DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
        database.

        The tables in the database is given as {table_names}. To start you should ALWAYS look at the 
        tables and their columns in the database to see what you can query. Do NOT skip this step.

        Then you should query the schema of the most relevant tables.
        """.format(
            dialect="SQLite",
            table_names=", ".join(table_names)
        )

        toolkit = SQLDatabaseToolkit(db=db, llm=llm)
        tools = toolkit.get_tools()
        agent_executor = create_react_agent(llm, tools, prompt=system_message)

        answer = agent_executor.invoke(
            {"messages": [{"role": "user", "content": item["question"]}]},
            config={"recursion_limit": 10}
        )

        # Parse the transcript once and take both fields from the same result.
        extracted = extract_sql_and_result_from_agent_output(answer)
        generated_query = extracted["generated_sql"]
        generated_result = extracted["generated_result"]

        hard_answers.append({
            "db_id": item["db_id"],
            "question": item["question"],
            "generated_sql": generated_query or "",
            "result": generated_result or "",
            "gold_sql": item["query"],
            "gold_result": item["result"]
        })

        print(f"✅ Success at index {i}")
        print(hard_answers[-1])

        with open(save_path, "w") as f:
            json.dump(hard_answers, f, indent=2)

    except Exception as e:
        error_msg = str(e)
        if "GRAPH_RECURSION_LIMIT" in error_msg or "Recursion limit" in error_msg:
            print(f"⚠️ Recursion limit hit at index {i}. Skipping.")
            # Keep the gold fields on the empty prediction. The evaluation cells parse a missing
            # "gold_result" as an empty result, so this empty prediction would otherwise be
            # scored as a correct execution match.
            hard_answers.append({
                "db_id": item["db_id"],
                "question": item["question"],
                "generated_sql": "",
                "result": "",
                "gold_sql": item["query"],
                "gold_result": item["result"]
            })
            print(hard_answers[-1])
        elif "tool_use_failed" in error_msg or "table" in error_msg.lower():
            # Nothing is recorded and the loop stops, so this question is retried on the next run.
            print(f"⚠️ Tool failure or table issue at index {i}. Stopping.")
            break
        else:
            print(f"❌ Error at index {i}: {e}")
            break

        with open(save_path, "w") as f:
            json.dump(hard_answers, f, indent=2)

## Medium Questions

In [None]:
import json

# Load the medium gold examples and rename gold_sql / gold_result to the short keys used below.
with open("extracted_sql_examples_medium.json", "r") as f:
    data = json.load(f)

golden_data_medium = []
for record in data:
    golden_data_medium.append({
        "db_id": record["db_id"],
        "question": record["question"],
        "query": record["gold_sql"],
        "result": record["gold_result"],
    })

In [None]:
# Preview the first medium example (keys: db_id, question, query, result).
golden_data_medium[0]

In [None]:
import json
import os
from tqdm import tqdm
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langgraph.prebuilt import create_react_agent

# Run the ReAct SQL agent over every medium question. Progress is checkpointed to save_path after
# each question, so an interrupted run resumes where it stopped.
save_path = "medium_answers.json"
# Root of the Spider databases. Set SPIDER_DB_DIR to point elsewhere.
spider_db_dir = os.environ.get("SPIDER_DB_DIR", "/Users/onurcanmemis/Downloads/spider_data/database")

# Load previous progress if available
if os.path.exists(save_path):
    with open(save_path, "r") as f:
        medium_answers = json.load(f)
else:
    medium_answers = []

start_index = len(medium_answers)

# Wrap tqdm around the remaining examples
for i in tqdm(range(start_index, len(golden_data_medium)), desc="Generating SQL", unit="q"):
    item = golden_data_medium[i]
    try:
        print(item["question"])

        # Setup DB first (important: get tables *after* db is loaded)
        db = SQLDatabase.from_uri(f"sqlite:///{spider_db_dir}/{item['db_id']}/{item['db_id']}.sqlite")
        table_names = db.get_usable_table_names()

        system_message = """
        You are an agent designed to interact with a SQL database.
        Given an input question, create a syntactically correct {dialect} query to run,
        then look at the results of the query and return the answer.
        You can order the results by a relevant column to return the most interesting
        examples in the database. Never query for all the columns from a specific table,
        only ask for the relevant columns given the question.

        You MUST double check your query before executing it. If you get an error while
        executing a query, rewrite the query and try again.

        DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
        database.

        The tables in the database is given as {table_names}. To start you should ALWAYS look at the tables and their columns in the database to see what you
        can query. Do NOT skip this step.

        Then you should query the schema of the most relevant tables.
        """.format(
            dialect="SQLite",
            table_names=", ".join(table_names)
        )

        toolkit = SQLDatabaseToolkit(db=db, llm=llm)
        tools = toolkit.get_tools()
        agent_executor = create_react_agent(llm, tools, prompt=system_message)

        answer = agent_executor.invoke(
            {"messages": [{"role": "user", "content": item["question"]}]},
            config={"recursion_limit": 10}
        )

        # Parse the transcript once and take both fields from the same result.
        extracted = extract_sql_and_result_from_agent_output(answer)
        generated_query = extracted["generated_sql"]
        generated_result = extracted["generated_result"]

        medium_answers.append({
            "db_id": item["db_id"],
            "question": item["question"],
            "generated_sql": generated_query or "",
            "result": generated_result or "",
            "gold_sql": item["query"],
            "gold_result": item["result"]
        })

        print(f"✅ Success at index {i}")
        print(medium_answers[-1])

        with open(save_path, "w") as f:
            json.dump(medium_answers, f, indent=2)

    except Exception as e:
        error_msg = str(e)
        if "GRAPH_RECURSION_LIMIT" in error_msg or "Recursion limit" in error_msg:
            print(f"⚠️ Recursion limit hit at index {i}. Skipping.")
            # Keep the gold fields on the empty prediction. The evaluation cells parse a missing
            # "gold_result" as an empty result, so this empty prediction would otherwise be
            # scored as a correct execution match.
            medium_answers.append({
                "db_id": item["db_id"],
                "question": item["question"],
                "generated_sql": "",
                "result": "",
                "gold_sql": item["query"],
                "gold_result": item["result"]
            })
            print(medium_answers[-1])
        elif "tool_use_failed" in error_msg or "table" in error_msg.lower():
            # Nothing is recorded and the loop stops, so this question is retried on the next run.
            print(f"⚠️ Tool failure or table issue at index {i}. Stopping.")
            break
        else:
            print(f"❌ Error at index {i}: {e}")
            break

        with open(save_path, "w") as f:
            json.dump(medium_answers, f, indent=2)

## Easy Questions

In [None]:
import json

# Load the easy gold examples and rename gold_sql / gold_result to the short keys used below.
with open("extracted_sql_examples_easy.json", "r") as f:
    data = json.load(f)

golden_data_easy = []
for record in data:
    golden_data_easy.append({
        "db_id": record["db_id"],
        "question": record["question"],
        "query": record["gold_sql"],
        "result": record["gold_result"],
    })

In [None]:
# Preview the first easy example (keys: db_id, question, query, result).
golden_data_easy[0]

In [None]:
import json
import os
from tqdm import tqdm
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langgraph.prebuilt import create_react_agent

# Run the ReAct SQL agent over every easy question. Progress is checkpointed to save_path after each
# question, so an interrupted run resumes where it stopped.
save_path = "easy_answers.json"
# Root of the Spider databases. Set SPIDER_DB_DIR to point elsewhere.
spider_db_dir = os.environ.get("SPIDER_DB_DIR", "/Users/onurcanmemis/Downloads/spider_data/database")

# Load previous progress if available
if os.path.exists(save_path):
    with open(save_path, "r") as f:
        easy_answers = json.load(f)
else:
    easy_answers = []

start_index = len(easy_answers)

# Wrap tqdm around the remaining examples
for i in tqdm(range(start_index, len(golden_data_easy)), desc="Generating SQL", unit="q"):
    item = golden_data_easy[i]
    try:
        print(item["question"])

        # Open this question's database BEFORE building the prompt. Otherwise the prompt would
        # list the tables of whatever `db` was left over from the previous question or cell.
        db = SQLDatabase.from_uri(f"sqlite:///{spider_db_dir}/{item['db_id']}/{item['db_id']}.sqlite")
        table_names = db.get_usable_table_names()

        system_message = """
        You are an agent designed to interact with a SQL database.
        Given an input question, create a syntactically correct {dialect} query to run,
        then look at the results of the query and return the answer.
        You can order the results by a relevant column to return the most interesting
        examples in the database. Never query for all the columns from a specific table,
        only ask for the relevant columns given the question.

        You MUST double check your query before executing it. If you get an error while
        executing a query, rewrite the query and try again.

        DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
        database.

        The tables in the database is given as {table_names}. To start you should ALWAYS look at the tables and their columns in the database to see what you
        can query. Do NOT skip this step.

        Then you should query the schema of the most relevant tables.
        """.format(
            dialect="SQLite",
            table_names=", ".join(table_names)
        )

        toolkit = SQLDatabaseToolkit(db=db, llm=llm)
        tools = toolkit.get_tools()
        agent_executor = create_react_agent(llm, tools, prompt=system_message)
        answer = agent_executor.invoke(
            {"messages": [{"role": "user", "content": item["question"]}]},
            config={"recursion_limit": 10}
        )

        # Parse the transcript once and take both fields from the same result.
        extracted = extract_sql_and_result_from_agent_output(answer)
        generated_query = extracted["generated_sql"]
        generated_result = extracted["generated_result"]

        easy_answers.append({
            "db_id": item["db_id"],
            "question": item["question"],
            "generated_sql": generated_query or "",
            "result": generated_result or "",
            "gold_sql": item["query"],
            "gold_result": item["result"]
        })
        print(f"✅ Success at index {i}")
        print(easy_answers[-1])

        # Save progress after each successful query
        with open(save_path, "w") as f:
            json.dump(easy_answers, f, indent=2)

    except Exception as e:
        error_msg = str(e)
        if "GRAPH_RECURSION_LIMIT" in error_msg or "Recursion limit" in error_msg:
            print(f"⚠️ Recursion limit hit at index {i}. Skipping.")
            # Keep the gold fields on the empty prediction. Without them the evaluation cells see
            # an empty gold result, and this empty prediction would be scored as a match.
            easy_answers.append({
                "db_id": item["db_id"],
                "question": item["question"],
                "generated_sql": "",
                "result": "",
                "gold_sql": item["query"],
                "gold_result": item["result"]
            })
            print(easy_answers[-1])
        elif "tool_use_failed" in error_msg or "table" in error_msg.lower():
            # Nothing is recorded and the loop stops, so this question is retried on the next run.
            print(f"⚠️ Tool failure or table issue at index {i}. Stopping.")
            break
        else:
            print(f"❌ Error at index {i}: {e}")
            break

        with open(save_path, "w") as f:
            json.dump(easy_answers, f, indent=2)

In [None]:
import json
import ast

# Agent answers for the easy split, as written by the easy generation loop.
easy_answers_path = "easy_answers.json"
with open(easy_answers_path, "r") as answers_file:
    data = json.load(answers_file)

def parse_result_string(result_str):
    """
    Turn a stored result string such as "('a', 1), ('b', 2)" into [('a', 1), ('b', 2)].

    The text is wrapped in brackets and read with ast.literal_eval. Anything that does not
    parse (error messages, free text) yields an empty list.
    """
    try:
        parsed = ast.literal_eval(f"[{result_str}]")
    except Exception:
        return []
    return parsed if isinstance(parsed, list) else []

# Pair each record's SQL with its parsed result: one list for the agent, one for the gold answers.
predictions = [
    (entry.get("generated_sql", ""), parse_result_string(entry.get("result", "")))
    for entry in data
]
gold_data = [
    (entry.get("gold_sql", ""), parse_result_string(entry.get("gold_result", "")))
    for entry in data
]

# Check samples
print("Example prediction:", predictions[0])
print("Example gold:", gold_data[0])

## Evaluation

In [None]:
from typing import List, Tuple, Dict
import sqlparse

def normalize_sql(sql):
    """
    Canonicalise a SQL string for exact-match comparison.

    Drops surrounding whitespace and trailing semicolons, lets sqlparse lower-case keywords,
    strip comments and reindent, then collapses all whitespace runs to single spaces and
    lower-cases the whole string.
    """
    trimmed = sql.strip().rstrip(';')
    formatted = sqlparse.format(trimmed, keyword_case='lower', strip_comments=True, reindent=True)
    # str.split() with no argument already discards leading/trailing whitespace.
    return " ".join(formatted.split()).lower()

def normalize_result(result):
    """
    Convert a list of result rows into a set of string tuples, ignoring row order.

    Each cell is converted with str() and stripped. None rows and rows that cannot be
    iterated are skipped. Anything that is not a non-empty list yields an empty set.
    """
    if not isinstance(result, list) or not result:
        return set()

    rows = set()
    for row in result:
        if row is None:
            continue
        try:
            cells = tuple(str(cell).strip() for cell in row)
        except Exception:
            continue  # non-iterable / malformed row
        rows.add(cells)
    return rows
def extract_components(sql):
    """
    Collect identifier-like tokens from the SELECT, FROM and WHERE parts of a query.

    Only the first statement sqlparse finds is inspected. Its top-level tokens are walked in
    order, and grouped tokens are flattened to their leaf tokens. A small state machine tracks
    the current clause. The state changes only on the keywords "select", "from" and "where".
    Any other keyword (JOIN, GROUP BY, ORDER BY, ...) leaves the state unchanged, so those
    tokens are attributed to the most recent of the three clauses.

    Args:
        sql: SQL text. In this notebook it is the output of normalize_sql (already lower-cased).

    Returns:
        dict with keys "select", "from" and "where", each a set of lower-cased token strings:
        - select: Name and Wildcard tokens
        - from: Name tokens
        - where: Name, integer-literal and comparison-operator tokens
        Token types are matched by exact ttype equality against these tuples.
        Empty input (nothing parsed) gives three empty sets.
    """
    parsed = sqlparse.parse(sql)
    if not parsed:
        return {"select": set(), "from": set(), "where": set()}
    stmt = parsed[0]

    select_tokens = set()
    from_tokens = set()
    where_tokens = set()

    # Clause state: at most one flag is True at a time once a clause keyword has been seen.
    is_select = False
    is_from = False
    is_where = False

    for token in stmt.tokens:
        if token.is_group:
            # Grouped token: visit every leaf so names nested inside groups are collected too.
            for subtoken in token.flatten():
                tval = subtoken.value.lower().strip()
                if tval in ("select", "from", "where"):
                    is_select = tval == "select"
                    is_from = tval == "from"
                    is_where = tval == "where"
                    continue
                if is_select and subtoken.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Wildcard):
                    select_tokens.add(tval)
                elif is_from and subtoken.ttype in (sqlparse.tokens.Name,):
                    from_tokens.add(tval)
                elif is_where and subtoken.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Literal.Number.Integer, sqlparse.tokens.Operator.Comparison):
                    where_tokens.add(tval)
        else:
            # Ungrouped top-level token: same clause tracking and collection as the branch above.
            tval = token.value.lower().strip()
            if tval in ("select", "from", "where"):
                is_select = tval == "select"
                is_from = tval == "from"
                is_where = tval == "where"
                continue
            if is_select and token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Wildcard):
                select_tokens.add(tval)
            elif is_from and token.ttype in (sqlparse.tokens.Name,):
                from_tokens.add(tval)
            elif is_where and token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Literal.Number.Integer, sqlparse.tokens.Operator.Comparison):
                where_tokens.add(tval)

    return {"select": select_tokens, "from": from_tokens, "where": where_tokens}

def jaccard_similarity(set1, set2):
    """
    Jaccard index |A ∩ B| / |A ∪ B| of two sets.

    Two empty sets are treated as identical (1.0).
    """
    if not (set1 or set2):
        return 1.0
    union_size = len(set1.union(set2))
    if union_size == 0:
        return 0
    return len(set1.intersection(set2)) / union_size
def evaluate_predictions(
    predictions: List[Tuple[str, List[Tuple]]],
    gold_data: List[Tuple[str, List[Tuple]]]
) -> Dict[str, float]:
    """
    Score predicted SQL queries and their results against gold queries and results.

    Each element of both lists is a pair (sql_query, result_rows), aligned by position.

    Metrics (all averaged over the number of pairs; 0.0 when there are none):
        - exact_match_accuracy: normalize_sql(pred) == normalize_sql(gold)
        - execution_match_accuracy: normalize_result(pred rows) == normalize_result(gold rows)
        - select/from/where_jaccard: mean Jaccard similarity of the clause tokens
          returned by extract_components on the normalised SQL

    Returns:
        dict with "total" plus the five metrics above.

    Raises:
        ValueError: if the two lists differ in length. zip() would otherwise silently drop
            the unmatched tail while the averages still divided by len(predictions).
    """
    if len(predictions) != len(gold_data):
        raise ValueError(
            f"predictions ({len(predictions)}) and gold_data ({len(gold_data)}) must have the same length"
        )

    total = len(predictions)
    exact_match_count = 0
    execution_match_count = 0

    jaccard_sums = {
        "select": 0.0,
        "from": 0.0,
        "where": 0.0
    }

    for (pred_sql, pred_result), (gold_sql, gold_result) in zip(predictions, gold_data):
        norm_pred_sql = normalize_sql(pred_sql)
        norm_gold_sql = normalize_sql(gold_sql)

        if norm_pred_sql == norm_gold_sql:
            exact_match_count += 1

        if normalize_result(pred_result) == normalize_result(gold_result):
            execution_match_count += 1

        pred_components = extract_components(norm_pred_sql)
        gold_components = extract_components(norm_gold_sql)

        for key in jaccard_sums:
            jaccard_sums[key] += jaccard_similarity(pred_components.get(key, set()), gold_components.get(key, set()))

    # Guard every average the same way; the jaccard means previously raised ZeroDivisionError on empty input.
    def _mean(value):
        return value / total if total else 0.0

    return {
        "total": total,
        "exact_match_accuracy": _mean(exact_match_count),
        "execution_match_accuracy": _mean(execution_match_count),
        "select_jaccard": _mean(jaccard_sums["select"]),
        "from_jaccard": _mean(jaccard_sums["from"]),
        "where_jaccard": _mean(jaccard_sums["where"])
    }

In [None]:
import json
import ast

# Load the merged JSON file
with open("hard_answers.json", "r") as f:
    data = json.load(f)

# Build (sql, parsed_rows) pairs. parse_result_string is the helper defined in the easy-answers
# cell before the Evaluation section; it is reused here rather than redefined.
hard_predictions = []
hard_gold_data = []

for item in data:
    pred_sql = item.get("generated_sql", "")
    pred_result = parse_result_string(item.get("result", ""))

    gold_sql = item.get("gold_sql", "")
    gold_result = parse_result_string(item.get("gold_result", ""))

    hard_predictions.append((pred_sql, pred_result))
    hard_gold_data.append((gold_sql, gold_result))

# Check samples (skipped for an empty answers file instead of raising IndexError)
if hard_predictions:
    print("Example prediction:", hard_predictions[0])
    print("Example gold:", hard_gold_data[0])

In [None]:
# Metrics for the hard split.
evaluate_predictions(predictions=hard_predictions, gold_data=hard_gold_data)

In [None]:
import json
import ast

# Load the merged JSON file
with open("medium_answers.json", "r") as f:
    data = json.load(f)

# Build (sql, parsed_rows) pairs. parse_result_string is the helper defined in the easy-answers
# cell before the Evaluation section; it is reused here rather than redefined.
medium_predictions = []
medium_gold_data = []

for item in data:
    pred_sql = item.get("generated_sql", "")
    pred_result = parse_result_string(item.get("result", ""))

    gold_sql = item.get("gold_sql", "")
    gold_result = parse_result_string(item.get("gold_result", ""))

    medium_predictions.append((pred_sql, pred_result))
    medium_gold_data.append((gold_sql, gold_result))

# Check samples (skipped for an empty answers file instead of raising IndexError)
if medium_predictions:
    print("Example prediction:", medium_predictions[0])
    print("Example gold:", medium_gold_data[0])

In [None]:
# Metrics for the medium split.
evaluate_predictions(predictions=medium_predictions, gold_data=medium_gold_data)

In [None]:
import json
import ast

# Load the merged JSON file
with open("easy_answers.json", "r") as f:
    data = json.load(f)

# Build (sql, parsed_rows) pairs. parse_result_string is the helper defined in the easy-answers
# cell before the Evaluation section; it is reused here rather than redefined.
easy_predictions = []
easy_gold_data = []

for item in data:
    pred_sql = item.get("generated_sql", "")
    pred_result = parse_result_string(item.get("result", ""))

    gold_sql = item.get("gold_sql", "")
    gold_result = parse_result_string(item.get("gold_result", ""))

    easy_predictions.append((pred_sql, pred_result))
    easy_gold_data.append((gold_sql, gold_result))

# Check samples (skipped for an empty answers file instead of raising IndexError)
if easy_predictions:
    print("Example prediction:", easy_predictions[0])
    print("Example gold:", easy_gold_data[0])

In [None]:
# Show a handful of (sql, rows) pairs. Displaying the whole list floods the notebook output.
easy_predictions[:5]

In [None]:
# Metrics for the easy split.
evaluate_predictions(predictions=easy_predictions, gold_data=easy_gold_data)

## Logic Equivalence Check

In [None]:
import json

def extract_question_and_result(file_path):
    """
    Parse a plain-text run log into question / execution-result pairs.

    Lines are stripped and scanned in order. A line starting with "Question:" sets (or
    overwrites) the pending question. The next line starting with "Execution Result:"
    completes the record, which is kept only if a question was seen. The pending state is
    then reset. All other lines are ignored.

    Args:
        file_path: path to the text log.

    Returns:
        list of {"question": str, "execution_result": str} dicts in file order.
    """
    question_prefix = "Question:"
    result_prefix = "Execution Result:"

    data = []
    current = {}

    with open(file_path, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith(question_prefix):
                # Slice off only the leading label. str.replace would also delete any
                # "Question:" text that appears inside the question itself.
                current["question"] = line[len(question_prefix):].strip()
            elif line.startswith(result_prefix):
                current["execution_result"] = line[len(result_prefix):].strip()
                if "question" in current:  # Ensure both parts exist
                    data.append(current)
                current = {}  # Reset for next instance

    return data

# Parse each difficulty's raw log and persist its question / execution_result pairs as JSON.
easy_data = extract_question_and_result("easy.txt")
medium_data = extract_question_and_result("medium.txt")
hard_data = extract_question_and_result("hard.txt")

for out_path, records in (
    ("questions_easy.json", easy_data),
    ("questions_medium.json", medium_data),
    ("questions_hard.json", hard_data),
):
    with open(out_path, "w") as f:
        json.dump(records, f, indent=2)

print("✅ JSON files written with question and execution_result.")

In [None]:
def _load_questions(path):
    """Read one questions_*.json file written by the previous cell."""
    with open(path, "r") as handle:
        return json.load(handle)


easy_data = _load_questions("questions_easy.json")
medium_data = _load_questions("questions_medium.json")
hard_data = _load_questions("questions_hard.json")

for label, records in (("easy", easy_data), ("medium", medium_data), ("hard", hard_data)):
    print(f"Number of {label} questions:", len(records))

In [None]:
# Keep only the parsed result of each (sql, rows) prediction, stringified for logic_equivalence_score.
easy_execution = [str(pair[1]) for pair in easy_predictions]
medium_execution = [str(pair[1]) for pair in medium_predictions]
hard_execution = [str(pair[1]) for pair in hard_predictions]

In [None]:
# The text logs print results with curly braces. Swap every "{" / "}" for "[" / "]" so the strings
# parse as lists in normalize_result_string.
braces_to_brackets = str.maketrans("{}", "[]")

custom_easy = [entry["execution_result"].translate(braces_to_brackets) for entry in easy_data]
custom_medium = [entry["execution_result"].translate(braces_to_brackets) for entry in medium_data]
custom_hard = [entry["execution_result"].translate(braces_to_brackets) for entry in hard_data]

In [None]:
import ast

def normalize_result_string(result_str):
    """
    Turn a stringified query result into a set of row tuples for order-insensitive comparison.

    Accepts a bracketed list ("[('Alice',), ('Bob',)]") or bare comma-separated tuples
    ("('Alice',), ('Bob',)"). The latter is wrapped in brackets before parsing. Every cell is
    converted with str() and stripped. Blank or unparseable input yields an empty set, and so
    does a non-iterable row anywhere in the list.
    """
    if result_str.strip() == "":
        return set()

    source = result_str
    if not source.strip().startswith("["):
        source = f"[{source}]"

    try:
        rows = ast.literal_eval(source)
        if not isinstance(rows, list):
            return set()
        return {tuple(str(cell).strip() for cell in row) for row in rows}
    except Exception:
        return set()

def logic_equivalence_score(langchain_results, agentic_results):
    """
    Fraction of positions whose two result strings normalise to the same set of rows.

    Both lists must have the same length (enforced by assert). Returns 0.0 for empty input.
    """
    assert len(langchain_results) == len(agentic_results), "Mismatched result lengths"

    total = len(langchain_results)
    if total <= 0:
        return 0.0

    matches = sum(
        1
        for lhs, rhs in zip(langchain_results, agentic_results)
        if normalize_result_string(lhs) == normalize_result_string(rhs)
    )
    return matches / total

In [None]:
# Easy split: agreement between the agent's results and the custom pipeline's results.
logic_equivalence_score(langchain_results=easy_execution, agentic_results=custom_easy)

In [None]:
# Medium split: agreement between the agent's results and the custom pipeline's results.
logic_equivalence_score(langchain_results=medium_execution, agentic_results=custom_medium)

In [None]:
# Hard split: agreement between the agent's results and the custom pipeline's results.
logic_equivalence_score(langchain_results=hard_execution, agentic_results=custom_hard)