In [1]:
import pandas as pd
import time
import subprocess
import os
import sys
import re
import ast
from itertools import permutations
from tqdm.auto import tqdm
tqdm.pandas()

# clean the query
def clean_mql_query(query):
    pattern = re.compile(r'[\s\S]*?(?=\[Q)')
    result = pattern.search(query)
    if result:
        result = result.group()
    else:
        result = query
    return result.replace("[MongoDB]:","").strip()

# formulate the queries
def add_quotes_to_keys(s):
    # add space after "}" and before "}"
    s = re.sub(r'({)', r'\1 ', str(s))
    s = re.sub(r'(})', r' \1', str(s))
    s = re.sub(r'\{\s*(\w+)\s*:', r"{ '\1':", str(s))
    s = re.sub(r',\s*(\w+)\s*:', r", '\1':", str(s))
    res = re.sub(r"Long\('(\w+)'\)", r'\1', str(s))
    res = res.strip()
    # trim the outer "()" if exists
    if res[0] == "(":
        if res[-1] == ")":
            res = res[1:-1]
        else:
            res = res[1:]
    # trim the last "," if exists
    if res[-1] == ',':
        res = res[:-1]
    return f"[{res}]"

# string to python object
def str_to_obj(s):
    try:
        return ast.literal_eval(add_quotes_to_keys(s).replace("null", "None"))
    except (ValueError, SyntaxError) as e:
        #if failed return s direct
        return s

# func to 
def dicts_to_tuples(list_of_dicts):
    try:
        list_of_dicts = [tuple(d.values()) for d in list_of_dicts]
    except AttributeError:
        raise ValueError(f"Input must be a list of dictionaries {list_of_dicts}")
    finally:
        return list_of_dicts

def permutation_list(list1):
    columns = list(zip(*list1))
    perm_columns = permutations(columns)
    result = []
    for perm in perm_columns:
        permuted_list = [tuple(item) for item in zip(*perm)]
        result.append(permuted_list)
    return result

def compare_list(list1, list2, row_permutation=False, column_permutation=False):
    if len(list1) != len(list2):
        return False
    if row_permutation:
        def _compare_list(list1, list2):
            for item1 in list1:
                if item1 not in list2:
                    return False
            return True
    else:
        def _compare_list(list1, list2):
            return list1 == list2
    if column_permutation:
        permuted_list1 = permutation_list(list1)
        for permuted in permuted_list1:
            if _compare_list(permuted, list2):
                return True
    
    return _compare_list(list1, list2)

def compare(s1, s2):
    o1 = str_to_obj(s1)
    o2 = str_to_obj(s2)
    if type(o1) == list and type(o2) == list:
        try:
            t1 = dicts_to_tuples(o1)
            t2 = dicts_to_tuples(o2)
            return compare_list(t1, t2, True, True)
        except Exception as e:
            return False

    return s1.lower().strip() == s2.lower().strip()



In [2]:
def execute_mongo_query(query:str):
    if query.lower().strip()[:9] == "no answer":
        return query.strip(), 0.0
    start_time = time.time()
    records = "[]"
    try:
        process = subprocess.Popen(["mongosh", "--quiet", "healthcare"], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        stdout, stderr = process.communicate(input=query)
        
        # Check if there was an error
        if stderr:
            records = f"Error: {stderr.strip()}"
        else:
            # Check if stdout contains a MongoDB error message
            error_pattern = re.compile(r'Error: .*')
            error_match = error_pattern.search(stdout.strip())
            if error_match:
                error_message = error_match.group()
                records = error_message
            else:
                # Extract the results from the output
                results_pattern = re.compile(r'\[([\s\S]*?)\]')
                results_match = results_pattern.search(stdout.strip())
                if results_match:
                    records = results_match.group(1)
                    records = re.sub(r'\s+', ' ', records).strip()
    except Exception as e:
        records = f"Error: {str(e)}"
    finally:
        end_time = time.time()
        execution_time = end_time - start_time
        return records, execution_time

In [5]:
from pathlib import Path
root_path = Path.cwd()
root_path = root_path.parents[2]
# load files:
eval_file = root_path / "data/results/mql" / "EXP7/dev/dev_mql_prompt_schema_bm25_with_template_gpt-3.5-turbo-0125-batch.csv"
expert_file = root_path / "data/dataset/processed_data" / "processed_dev.csv"
output_file = root_path / "data/results/mql" / "EXP7/dev/dev_mql_prompt_schema_bm25_with_template_gpt-3.5-turbo-0125-batch-checkpoint.csv"

In [6]:

eval_df = pd.read_csv(eval_file)
expert_df = pd.read_csv(expert_file)

eval_df["cleaned_mql_llm_query"] = eval_df.progress_apply(lambda x: clean_mql_query(x['answers']), axis=1)
eval_df["mql_llm_results"], eval_df["mql_llm_time"] = zip(*eval_df['cleaned_mql_llm_query'].progress_apply(execute_mongo_query))



  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

In [7]:
total_df = pd.concat([expert_df, eval_df[["cleaned_mql_llm_query", "mql_llm_results", "mql_llm_time"]]], axis=1)
total_df["mql_label"] = total_df.progress_apply(lambda row: compare(row["mql_results"], row["mql_llm_results"]), axis=1)
total_df.to_csv(output_file, index=False)

  0%|          | 0/2000 [00:00<?, ?it/s]

In [8]:
len(total_df[total_df["mql_label"]])/len(total_df)

0.823

In [9]:
import math

def compute_ves(label, time_gt, time_pred):
    if time_pred == 0:
        return 0
    ves = int(label) * math.sqrt(time_gt/time_pred)
    return ves

In [10]:
(total_df.progress_apply(lambda x: compute_ves(x["mql_label"], x["mql_query_time"], x["mql_llm_time"]), axis=1).sum()/len(total_df)).item()

  0%|          | 0/2000 [00:00<?, ?it/s]

0.839861737617344