## Setup imports

In [1]:
import csv
import pandas as pd
from datasets import load_dataset
from database_interface import DatabaseHelper
from collections import Counter
MODEL_OUTPUT_PATH = "/scratch/eecs595f25_class_root/eecs595f25_class/llada_data/synthetic_text_to_sql/output_sql.csv"
OUTPUT_FILE_PATH = "/scratch/eecs595f25_class_root/eecs595f25_class/llada_data/synthetic_text_to_sql/valid_test.json"

## Load in initial dataset

In [2]:
retrieved_dataset = load_dataset("json", data_files=OUTPUT_FILE_PATH)["train"]
print(f"retrieved dataset, info={retrieved_dataset}")

retrieved dataset, info=Dataset({
    features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
    num_rows: 3434
})


## Load and merge the results with the examples

In [3]:
model_res_df = pd.read_csv(MODEL_OUTPUT_PATH)
gold_standard_df = retrieved_dataset.to_pandas()
gold_standard_df = gold_standard_df.iloc[0:len(model_res_df)][[
    'id',
    'sql_complexity',
    'sql',
    'sql_context',
    'sql_prompt',
    'sql_explanation'
    
]]

eval_df = pd.merge(model_res_df, gold_standard_df,on='id')
# print(eval_df)
print(eval_df.iloc[29])

id                                                                51
out_sql            ["\nSELECT COUNT(*) AS number_of_patientsFROM ...
sql_complexity                                           single join
sql                SELECT COUNT(*) FROM patients INNER JOIN treat...
sql_context        CREATE TABLE patients (id INT, country VARCHAR...
sql_prompt         What is the number of patients in India who re...
sql_explanation    First, we join the patients and treatments tab...
Name: 29, dtype: object


## Setup Bag-of-cells to do analysis of each data cell in outputs (Gemini provided)

In [4]:
def get_cell_counts(df_gold, df_predicted):
    """
    Calculates Precision, Recall, and F1 score based on a 
    "Bag of Cells" (multiset of all values).
    """
    
    # --- 1. Flatten Gold DataFrame into a Counter ---
    # .values.flatten() turns the whole DF into a 1D numpy array
    # We must handle potential 'None' values which Counter might not like
    gold_cells = Counter([str(v) for v in df_gold.values.flatten() if v is not None])
    
    # --- 2. Flatten Predicted DataFrame into a Counter ---
    predicted_cells = Counter([str(v) for v in df_predicted.values.flatten() if v is not None])

    # --- 3. Calculate Scores ---
    
    # Handle the "divide by zero" case if no cells are returned
    total_gold = sum(gold_cells.values())
    total_predicted = sum(predicted_cells.values())

    # & (intersection) finds the element-wise minimum
    intersection = gold_cells & predicted_cells
    intersection_count = sum(intersection.values())
    
    # precision = intersection_count / total_predicted
    # recall = intersection_count / total_gold
        
    return intersection_count, total_predicted, total_gold


#test
df_gold_1 = pd.DataFrame([[1.0, 2000.0], [1.0, 2000.0], [2.0, 4000.0], [2.0, 4000.0]])
df_pred_1 = pd.DataFrame([[1.0, 2000.0], [2.0, 4000.0]])
p, r, f1 = get_cell_counts(df_gold_1, df_pred_1)
print(f"Example 1 F1: {f1:.2f} (P={p:.2f}, R={r:.2f})")


df_gold_2 = pd.DataFrame([['Commercial', 70.0]])
df_pred_2 = pd.DataFrame([[70.0]])
p, r, f1 = get_cell_counts(df_gold_2, df_pred_2)
print(f"Example 2 F1: {f1:.2f} (P={p:.2f}, R={r:.2f})")

Example 1 F1: 8.00 (P=4.00, R=4.00)
Example 2 F1: 2.00 (P=1.00, R=1.00)


## Iterate over examples and process the results!

In [25]:
sql_difficulty = eval_df['sql_complexity'].unique()
complexity_eval_res = {complexity:{"correct":0, "parsable":0, "total":0, "cell_intersect":0, "cell_predict":0.000001, "cell_gold": 0.000001} for complexity in sql_difficulty}
print(complexity_eval_res)
count = 0
for index, row in eval_df.iterrows():
    model_sql_res = row.out_sql[1:len(row.out_sql)-1]
    complexity = row.sql_complexity
    complexity_eval_res[complexity]['total'] += 1
    if model_sql_res != "None":
        db_obj = DatabaseHelper(":memory:")
        db_obj.insert_data(row.sql_context)
        # Do minimal cleaning to pass the SQL
        cleaned_sql = model_sql_res.strip().strip("'").replace("\\n"," ").strip("\"")
        model_res, status = db_obj.fetch_data(cleaned_sql)
        
        # if status == "failure":
        #     continue
            
        # If SQL parsed, increment score
        complexity_eval_res[complexity]['parsable'] += 1
#         if model_res.empty:
#             # print(f"ID:{index} -- SQL that caused error:\n{model_sql_res}\n")
#             continue
            
        truth_res, status = db_obj.fetch_data(row.sql)
        # try:
        model_res_normalized = Counter(
            [tuple(sorted(row, key=str)) for row in model_res.values.tolist()]
        )
        truth_res_normalized = Counter(
            [tuple(sorted(row, key=str)) for row in truth_res.values.tolist()]
        )
        # except Exception as e:
        #     print("PARSING ISSUE")
        #     print(e)
        #     print("DATA:")
        #     print(truth_res)
        #     print(model_res)
        if model_res_normalized != truth_res_normalized and complexity == "subqueries":
            print(f"DEBUG:\n Model res:\n{model_res_normalized}\nTruth res:\n{truth_res_normalized}\n")
            print(f"Prompt:{row.sql_prompt}")
            print(f"SQL of res: {cleaned_sql}")
            print(f"SQL of truth: {row.sql}")
            print(f"SQL explanation: {row.sql_explanation}")
            print("---------")
        # Give correct point for results that are matched
        # Note: This is a naive method which doesn't consider ORDER by. We will assume the non-ordered data considered 'correct'
        
        complexity_eval_res[complexity]['correct'] += 1 if model_res_normalized == truth_res_normalized else 0
        # get cell matches
        intersect_count, predict_count, gold_count = get_cell_counts(truth_res, model_res)
    
        # if complexity == "subqueries":
        #     print("MODEL RES")
        #     print(model_res)
        #     print("TRUTH RES")
        #     print(truth_res)
        #     print(get_cell_counts(truth_res, model_res))
        #     count += 1
        #     print("number of subqueires considered:",count)
        complexity_eval_res[complexity]['cell_intersect'] += intersect_count
        complexity_eval_res[complexity]['cell_predict'] += predict_count
        complexity_eval_res[complexity]['cell_gold'] += gold_count
        
    
        
        

{'basic SQL': {'correct': 0, 'parsable': 0, 'total': 0, 'cell_intersect': 0, 'cell_predict': 1e-06, 'cell_gold': 1e-06}, 'aggregation': {'correct': 0, 'parsable': 0, 'total': 0, 'cell_intersect': 0, 'cell_predict': 1e-06, 'cell_gold': 1e-06}, 'window functions': {'correct': 0, 'parsable': 0, 'total': 0, 'cell_intersect': 0, 'cell_predict': 1e-06, 'cell_gold': 1e-06}, 'single join': {'correct': 0, 'parsable': 0, 'total': 0, 'cell_intersect': 0, 'cell_predict': 1e-06, 'cell_gold': 1e-06}, 'multiple_joins': {'correct': 0, 'parsable': 0, 'total': 0, 'cell_intersect': 0, 'cell_predict': 1e-06, 'cell_gold': 1e-06}, 'subqueries': {'correct': 0, 'parsable': 0, 'total': 0, 'cell_intersect': 0, 'cell_predict': 1e-06, 'cell_gold': 1e-06}, 'set operations': {'correct': 0, 'parsable': 0, 'total': 0, 'cell_intersect': 0, 'cell_predict': 1e-06, 'cell_gold': 1e-06}}
DEBUG:
 Model res:
Counter()
Truth res:
Counter({('Africa', nan): 1, ('Asia', nan): 1, (0.5, 'Europe'): 1})

Prompt:What is the differenc

In [26]:
print(complexity_eval_res)

{'basic SQL': {'correct': 120, 'parsable': 157, 'total': 577, 'cell_intersect': 230, 'cell_predict': 271.000001, 'cell_gold': 305.000001}, 'aggregation': {'correct': 39, 'parsable': 57, 'total': 234, 'cell_intersect': 242, 'cell_predict': 277.000001, 'cell_gold': 319.000001}, 'window functions': {'correct': 1, 'parsable': 12, 'total': 33, 'cell_intersect': 60, 'cell_predict': 88.000001, 'cell_gold': 115.000001}, 'single join': {'correct': 26, 'parsable': 45, 'total': 134, 'cell_intersect': 76, 'cell_predict': 109.000001, 'cell_gold': 115.000001}, 'multiple_joins': {'correct': 1, 'parsable': 6, 'total': 18, 'cell_intersect': 2, 'cell_predict': 2.000001, 'cell_gold': 29.000000999999997}, 'subqueries': {'correct': 3, 'parsable': 11, 'total': 48, 'cell_intersect': 4, 'cell_predict': 4.000001, 'cell_gold': 35.000001}, 'set operations': {'correct': 0, 'parsable': 2, 'total': 7, 'cell_intersect': 0, 'cell_predict': 1e-06, 'cell_gold': 12.000001000000001}}


In [29]:
for complexity, counts in complexity_eval_res.items():
    print(f"Complexity: {complexity}, num. samples = {counts['total']}")
    print(f"   Macro stats -- correct:{counts['correct']/counts['total']:.2} | parsable: {counts['parsable']/counts['total']:.2}")
    precision = counts['cell_intersect']/counts['cell_predict']
    recall = counts['cell_intersect']/counts['cell_gold']
    print(f"   Cell stats -- precision:{precision:.2} | recall:{recall:.2} | F1: {(2 * precision * recall) / (precision + recall+1e-9)}")
    # precision = intersection_count / total_predicted
    #recall = intersection_count / total_gold

Complexity: basic SQL, num. samples = 577
   Macro stats -- correct:0.21 | parsable: 0.27
   Cell stats -- precision:0.85 | recall:0.75 | F1: 0.7986111078398981
Complexity: aggregation, num. samples = 234
   Macro stats -- correct:0.17 | parsable: 0.24
   Cell stats -- precision:0.87 | recall:0.76 | F1: 0.8120805336901321
Complexity: window functions, num. samples = 33
   Macro stats -- correct:0.03 | parsable: 0.36
   Cell stats -- precision:0.68 | recall:0.52 | F1: 0.5911329986109831
Complexity: single join, num. samples = 134
   Macro stats -- correct:0.19 | parsable: 0.34
   Cell stats -- precision:0.7 | recall:0.66 | F1: 0.6785714220131138
Complexity: multiple_joins, num. samples = 18
   Macro stats -- correct:0.056 | parsable: 0.33
   Cell stats -- precision:1.0 | recall:0.069 | F1: 0.1290322496191472
Complexity: subqueries, num. samples = 48
   Macro stats -- correct:0.062 | parsable: 0.23
   Cell stats -- precision:1.0 | recall:0.11 | F1: 0.2051281944247211
Complexity: set oper