In [1]:
%pip install pickle-mixin openpyxl pandas tqdm collections-extended

Note: you may need to restart the kernel to use updated packages.


In [2]:
import re
import json
import pickle
import random
import openpyxl
import itertools
import pandas as pd
from datetime import datetime
from tqdm.autonotebook import tqdm
from collections import defaultdict

  from tqdm.autonotebook import tqdm


#### 定义输入输出路径，并加载数据
- ```dataset_dir```：标注好的数据集表格路径
- ```statistics_output_dir```：统计数据表输出路径，包括属性值、选择率、基数
- ```valid_where_output_dir```：所有有效谓词组合的输出路径
- **注意**：区分标注好的表格是 .csv 还是 .xlsx 格式，相应改变 pandas 读取形式


In [3]:
# dataset_dir = r"/home/sunzhaoze/Benchmark/LCR_All_Attr.csv"
dataset_dir = r"/data2/liujinqi/Benchmark/Query/LCR_AutoConstruct/LCR_All_Attr.xlsx"
statistics_output_dir = r"/data2/liujinqi/Benchmark/Query/LCR_AutoConstruct/LCR_Statistics.csv"
valid_where_output_dir = r"/data2/liujinqi/Benchmark/Query/LCR_AutoConstruct/valid_WHERE.json"
# df = pd.read_csv(dataset_dir)
df = pd.read_excel(dataset_dir)

#### 定义属性，方便后面按不同属性类型设计不同的构造方法
- ```attr_desc_dict```：全部属性的集合，以及对应的自然语言描述
- ```non_numerical_attr```：非数值属性的集合
- ```numerical_attr```：数值属性的集合
- ```non_formatted_attr```：非格式化数值属性
- ```formatted_attr```：格式化数值属性，如日期
- ```category_attr```：固定类别的属性

In [4]:
attr_desc_dict = {
     "judge_name": "",
     "plaintiff": "",
     "defendant": "",
     "hearing_year": "",
     "judgment_year": "",
     "charges": "",
     "case_type": "",
     "verdict": "",
     "legal_basis_num": "",
     "case_num": "",
     "counsel_for_applicant": "",
     "counsel_for_respondent": "",
     "nationality_for_applicant": "",
     "fine_amount": "",
     "legal_fees": "",
     "plaintiff_current_status": "",
     "defendant_current_status": "",
     "evidence": "",
     "first_judge": "",
}
non_numerical_attr = ["judge_name", "plaintiff", "defendant", "charges", 
                      "case_type", "verdict", "counsel_for_applicant", "counsel_for_respondent", 
                      "nationality_for_applicant",  "plaintiff_current_status" 
                      "defendant_current_status", "evidence", "first_judge"]
numerical_attr = ["hearing_year", "judgment_year", "legal_fees", "legal_basis_num", "case_num", "fine_amount"]
non_formatted_attr = ["legal_fees", "legal_basis_num", "case_num", "fine_amount"]
formatted_attr = ["hearing_year", "judgment_year"]
category_attr = ["verdict",  "case_type", "evidence", "first_judge"]
multi_value_attributes = []

#### 生成统计信息
- 属性 | 属性值 | 选择率 | 基数
- 用于后续构造 Filter

In [5]:
statistics = pd.DataFrame()

for column in attr_desc_dict.keys():
    if column in multi_value_attributes:
        expanded_values = []
        original_row_count = {} 
        
        for idx, value in enumerate(df[column]):
            if pd.isna(value):
                expanded_values.append(None)
                original_row_count[None] = original_row_count.get(None, 0) + 1
            elif isinstance(value, str) and '||' in value:
                split_values = [v.strip() for v in value.split('||') if v.strip()]
                expanded_values.extend(split_values)
                for split_val in split_values:
                    original_row_count[split_val] = original_row_count.get(split_val, 0) + 1
            else:
                expanded_values.append(value)
                original_row_count[value] = original_row_count.get(value, 0) + 1
        
        expanded_series = pd.Series(expanded_values)
        value_counts = expanded_series.value_counts()
        
        selectivities = pd.Series({
            k: round(v / len(df), 3) for k, v in original_row_count.items()
        })
        
    else:
        value_counts = df[column].value_counts()
        selectivities = df[column].value_counts(normalize=True).round(3)
    
    null_count = df[column].isnull().sum()
    
    result_df = pd.DataFrame({
        f"{column}": list(value_counts.index) + ["(null)"],
        'Count': list(value_counts.values) + [null_count],
        'Selectivity': [selectivities.get(val, 0) for val in value_counts.index] + [round(null_count / len(df), 3)]
    })
    
    if statistics.empty:
        statistics = result_df
    else:
        statistics = pd.concat([statistics, result_df], axis=1)

statistics.to_csv(statistics_output_dir, index=False)
statistics

Unnamed: 0,judge_name,Count,Selectivity,plaintiff,Count.1,Selectivity.1,defendant,Count.2,Selectivity.2,hearing_year,...,Selectivity.3,defendant_current_status,Count.3,Selectivity.4,evidence,Count.4,Selectivity.5,first_judge,Count.5,Selectivity.6
0,Tracey,31.0,0.052,Australian Competition and Consumer Commission,15.0,0.025,Minister for Immigration and Citizenship,160.0,0.269,NaT,...,0.370,Government,306.0,0.567,1,467.0,0.778,0,319.0,0.532
1,Flick,29.0,0.048,Australian Securities and Investments Commission,5.0,0.008,Minister for Immigration and Multicultural Aff...,51.0,0.086,2006-11-06,...,0.096,Company,137.0,0.254,0,132.0,0.220,1,281.0,0.468
2,Greenwood,26.0,0.043,Cadbury Schweppes Pty Ltd,5.0,0.008,Minister for Immigration and Multicultural and...,26.0,0.044,2006-03-09,...,0.057,Organization,55.0,0.102,Not explicitly mentioned,1.0,0.002,(null),0.0,0.000
3,Collier,18.0,0.030,Commissioner of Taxation,4.0,0.007,Commissioner of Taxation,12.0,0.020,2008-05-19,...,0.045,Employee,7.0,0.013,(null),0.0,0.000,,,
4,Spender,18.0,0.030,Australian Prudential Regulation Authority,3.0,0.005,Repatriation Commission,7.0,0.012,2008-05-20,...,0.039,Chief Executive Officer,2.0,0.004,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
543,,,,,,,,,,,...,,,,,,,,,,
544,,,,,,,,,,,...,,,,,,,,,,
545,,,,,,,,,,,...,,,,,,,,,,
546,,,,,,,,,,,...,,,,,,,,,,


#### 汇总可取属性值，以便后续据此构造 Filter
- ```attr_value_dict```：所有可取属性值的集合，每个属性有多个可取属性值，根据统计信息来定

In [6]:
attr_value_dict = {
     "judge_name": ["Tracey", "Flick", "Greenwood", "Collier", "Spender", "Besanko", "Moore", "Marshall", "Heerey", "Rares", "Bennett", "Jagot", "Cowdroy", "Conti", "Middleton", "Siopis", "Kenny", "Gyles", "Weinberg", "Graham", "French", "Gilmour", "Lander", "Gray", "Logan", "Gordon", "Finkelstein", "Mckerracher", "Edmonds", "Tamberlin", "Lindgren", "Buchanan"],
     "plaintiff": ["Australian Competition and Consumer Commission", "Australian Securities and Investments Commission", "Cadbury Schweppes Pty Ltd", "Commissioner of Taxation"],
     "defendant": ["Minister for Immigration and Citizenship", "Minister for Immigration and Multicultural Affairs", "Minister for Immigration and Multicultural and Indigenous Affairs", ],
     "charges": [],
     "case_type": ["Administrative Case", "Civil Case", "Commercial Case", "Criminal Case"],
     "verdict": ["Dismissed", "Approved", "Guilty", "Others"],
     "counsel_for_applicant": [],
     "counsel_for_respondent": [],
     "nationality_for_applicant": ["Australia", "China", "India", "Fiji", "Pakistan"],
     "plaintiff_current_status": ["Company", "Organization", "Government"],
     "defendant_current_status": ["Government", "Company", "Organization"],
     "evidence": ["1", "0"],
     "first_judge": ["1", "0"],
     "hearing_year": ["2005", "2006", "2007", "2009", "2005-3", "2005-4", "2005-7", "2005-10", "2005-11", "2006-1", "2006-4", "2006-5", "2006-6", "2006-8","2006-10", "2006-12", "2007-2", "2007-3", "2007-4", "2007-5", "2007-8", "2007-9", "2007-10", "2007-11", "2007-12", "2008-1", "2008-5", "2008-7", "2008-9", "2008-12", "2009-4", "2009-5", "2009-10", "2009-11", "2008-5-20", "2007-5-2", "2006-11-13", "2006-3-9"],
     "judgment_year": ["2005", "2006", "2008", "2009", "2005-2", "2005-3", "2005-4", "2005-8", "2005-9", "2005-10", "2005-11", "2005-12", "2006-1", "2006-2", "2006-3", "2006-4", "2006-5", "2006-6", "2006-7", "2006-8", "2006-9", "2006-10", "2006-11", "2006-12", "2007-1", "2007-2", "2007-3", "2007-7", "2007-8", "2007-10", "2007-12", "2008-3", "2008-8", "2008-12", "2009-1", "2009-5", "2009-6", "2009-9"],
     "legal_fees": [200000, 100000, 10000, 5000, 2000, 1200, 1000, 700, 300, 0],
     "legal_basis_num": [20, 10, 6, 3, 1, 0],
     "case_num": [50, 20, 10, 5, 2, 0],
     "fine_amount": [1000000, 100000, 50000, 10000, 5000, 1000, 100, 0]
}

#### 定义查询构造数量
- ```max_filters```：WHERE 语句最大 Filter 数量
- ```min_rows```：结果表最少行数
- ```max_select```：SELECT 语句最大属性数量
- ```limit_list```：LIMIT 语句候选值
- ```sample_sfw```：SELECT | FROM | WHERE 查询数量
- ```sample_sfwt```：SELECT | FROM | WHERE | TOP-K 查询数量
- ```sample_sfwg```：SELECT | FROM | WHERE | GROUP BY 查询数量
- ```sample_sfwa```：SELECT | FROM | WHERE | AGGREGATION 查询数量
- ```sample_sfwga```：SELECT | FROM | WHERE | GROUP BY | AGGREGATION 查询数量
- ```sample_sfwgat```：SELECT | FROM | WHERE | GROUP BY | AGGREGATION | TOP-K 查询数量

In [7]:
max_filters = 5
min_rows = 5
max_select = 3
limit_list = [1, 5, 10, 20, 50]
sample_sf = 10
sample_sfw = 10
sample_sfwt = 20
sample_sfwg = 20
sample_sfwa = 20
sample_sfwga = 30
sample_sfwgat = 20

#### 定义 Filter 具体执行方法
- 不同类型的数据有不同的执行方法
- 区分非数值型、数值型、日期型
- 比较运算包括：大于（只限于数值型）、小于（只限于数值型）、等于（任意类型）
- **注意**：需要提前确认好日期型数据的具体格式，做好特殊情况相应处理

In [8]:
def non_numerical_equal_to(value, condition):
     #print(f'value:{value}, condition:{condition}')
     try:
          value = str(value).lower().strip()
          condition = str(condition).lower().strip()
     except:
          print("Invalid data (non_numerical_equal_to). value:[%s] | condition: [%s]." % (value, condition))
     return value == condition

def number_greater_than(value, condition):
     if pd.isna(value):
          value = 0.00
     try:
          value = round(float(value), 2)
          condition = round(float(condition), 2)
     except:
          print("Invalid data (number_greater_than). value:[%s] | condition: [%s]." % (value, condition))
     return value > condition

def number_less_than(value, condition):
     if pd.isna(value):
          value = 0.00
     try:
          value = round(float(value), 2)
          condition = round(float(condition), 2)
     except:
          print("Invalid data (number_less_than). value:[%s] | condition: [%s]." % (value, condition))
     return value < condition

def number_equal_to(value, condition):
     if pd.isna(value):
          value = 0.00
     try:
          value = round(float(value), 2)
          condition = round(float(condition), 2)
     except:
          print("Invalid data (number_equal_to). value:[%s] | condition: [%s]." % (value, condition))
     return value == condition

def parse_date(date_str):
     if pd.isna(date_str):
          return None
     for fmt in ['%Y/%m/%d', '%Y/%m', '%Y','%Y-%m-%d', '%Y-%m']:
          try:
               return datetime.strptime(date_str, fmt)
          except ValueError:
               continue
     return None

def date_greater_than(value, condition):
     date_value = parse_date(value)
     condition_date = parse_date(condition)
     if date_value is None or condition_date is None:
          return False
     return date_value > condition_date

def date_less_than(value, condition):
     date_value = parse_date(value)
     condition_date = parse_date(condition)
     if date_value is None or condition_date is None:
          return False
     return date_value < condition_date

def date_equal_to(value, condition):
     date_value = parse_date(value)
     condition_date = parse_date(condition)
     if date_value is None or condition_date is None:
          return False
     return date_value == condition_date

def parse_century(value):
    """将世纪字符串解析为 (起始年份, 结束年份) 的元组。

    优雅地处理 NaN 值和不正确的格式。
    """
    if pd.isna(value):
        return None

    try:
        match = re.match(r"(\d+)(?:th|st|nd|rd)(?:-(\d+)(?:th|st|nd|rd))?", value)
        if not match:
            print(f"无效的世纪格式: {value}")  # 更具体的错误信息
            return None

        start_century = int(match.group(1))
        end_century = int(match.group(2)) if match.group(2) else start_century
        return (start_century - 1) * 100, end_century * 100

    except ValueError:
        print(f"无效的世纪格式 (ValueError): {value}")
        return None


def century_greater_than(value, condition):
    """检查一个世纪值是否大于另一个世纪值。

    优雅地处理 NaN 值和不正确的格式。
    """
    value_parsed = parse_century(value)
    condition_parsed = parse_century(condition)

    if value_parsed is None or condition_parsed is None:
        return False

    return value_parsed[0] > condition_parsed[1]

def century_less_than(value, condition):
    """检查一个世纪值是否大于另一个世纪值。

    优雅地处理 NaN 值和不正确的格式。
    """
    value_parsed = parse_century(value)
    condition_parsed = parse_century(condition)

    if value_parsed is None or condition_parsed is None:
        return False

    return value_parsed[0] < condition_parsed[1]

def number_greater_equal(value, condition):
    if pd.isna(value):
        value = 0.00
    try:
        value = round(float(value), 2)
        condition = round(float(condition), 2)
    except:
        print("Invalid data (number_greater_equal). value:[%s] | condition: [%s]." % (value, condition))
    return value >= condition

def number_less_equal(value, condition):
    if pd.isna(value):
        value = 0.00
    try:
        value = round(float(value), 2)
        condition = round(float(condition), 2)
    except:
        print("Invalid data (number_less_equal). value:[%s] | condition: [%s]." % (value, condition))
    return value <= condition

def date_greater_equal(value, condition):
    date_value = parse_date(value)
    condition_date = parse_date(condition)
    if date_value is None or condition_date is None:
        return False
    return date_value >= condition_date

def date_less_equal(value, condition):
    date_value = parse_date(value)
    condition_date = parse_date(condition)
    if date_value is None or condition_date is None:
        return False
    return date_value <= condition_date


#### 枚举所有可能的 Filter
- Filter 的可能取值来自于字典：```attr_value_dict```
- 用一个列表记录满足 Filter 的行索引
- 输入：```dataset_dir```，即原始标注数据表
- 输出：```filter_dict.json```，包含各属性，每个属性中包含所有可能的过滤条件及对应满足条件的行索引号
- 这一步构建好所有可能的 Filter，后面从此 Filter 集合中选取 Filter 进行排列组合，构造 WHERE 语句

In [9]:
df = pd.read_excel(dataset_dir)

filter_dict = {}

for key in attr_value_dict.keys():
    filter_dict[key] = {}

for key, value in tqdm(filter_dict.items()):
    condition_dict = {}

    ###### 非数值属性，只取等于操作 ######
    if key in non_numerical_attr:
        for possible_value in attr_value_dict[key]:
            result = df[key].apply(non_numerical_equal_to, condition=possible_value)
            result_indices = df[result].index.tolist()
            condition_dict["=='" + str(possible_value) + "'"] = result_indices

    ###### 数值属性，取5种操作：<, >, =, <=, >= ######
    if key in numerical_attr:
        for possible_value in attr_value_dict[key]:
            
            # 初始化所有结果索引列表
            result_indices_less = []
            result_indices_greater = []
            result_indices_equal = []
            result_indices_less_equal = []
            result_indices_greater_equal = []

            ###### 非格式化数值属性（普通数值）######
            if key in non_formatted_attr:
                result_index_less = df[key].apply(number_less_than, condition=possible_value)
                result_index_greater = df[key].apply(number_greater_than, condition=possible_value)
                result_index_equal = df[key].apply(number_equal_to, condition=possible_value)
                result_index_less_equal = df[key].apply(number_less_equal, condition=possible_value)
                result_index_greater_equal = df[key].apply(number_greater_equal, condition=possible_value)
                
                result_indices_less = df[result_index_less].index.tolist()
                result_indices_greater = df[result_index_greater].index.tolist()
                result_indices_equal = df[result_index_equal].index.tolist()
                result_indices_less_equal = df[result_index_less_equal].index.tolist()
                result_indices_greater_equal = df[result_index_greater_equal].index.tolist()

            ###### 格式化数值属性（日期）######
            elif key in formatted_attr:
                result_index_less = df[key].apply(date_less_than, condition=possible_value)
                result_index_greater = df[key].apply(date_greater_than, condition=possible_value)
                result_index_equal = df[key].apply(date_equal_to, condition=possible_value)
                result_index_less_equal = df[key].apply(date_less_equal, condition=possible_value)
                result_index_greater_equal = df[key].apply(date_greater_equal, condition=possible_value)
                
                result_indices_less = df[result_index_less].index.tolist()
                result_indices_greater = df[result_index_greater].index.tolist()
                result_indices_equal = df[result_index_equal].index.tolist()
                result_indices_less_equal = df[result_index_less_equal].index.tolist()
                result_indices_greater_equal = df[result_index_greater_equal].index.tolist()

            # 将所有5种操作添加到条件字典中
            condition_dict["<" + str(possible_value)] = result_indices_less
            condition_dict[">" + str(possible_value)] = result_indices_greater
            condition_dict["==" + str(possible_value)] = result_indices_equal
            condition_dict["<=" + str(possible_value)] = result_indices_less_equal
            condition_dict[">=" + str(possible_value)] = result_indices_greater_equal
    
    filter_dict[key] = condition_dict

with open("./filter_dict.json", 'w') as f:
    json.dump(filter_dict, f, ensure_ascii=False)

  0%|          | 0/19 [00:00<?, ?it/s]

#### 枚举所有 Filter 的可能排列组合，构造 WHERE 语句
- ```balanced_sample()```：控制均匀采样 Filter
     - 保证每个属性出现的次数是平衡的

In [10]:
###### 均匀采样函数 ######
def balanced_sample(filters, sample_size=10, random_seed=None):
    if random_seed is not None:
        random.seed(random_seed)
    
    col_to_filters = defaultdict(list)
    for filter_item in filters:
        col = filter_item[0]
        col_to_filters[col].append(filter_item)
    
    unique_cols = list(col_to_filters.keys())
    num_cols = len(unique_cols)
    
    if num_cols == 0:
        return []
    
    base_num = sample_size // num_cols
    remainder = sample_size % num_cols
    
    col_sample_counts = {col: base_num for col in unique_cols}

    for col in random.sample(unique_cols, remainder):
        col_sample_counts[col] += 1
    
    sampled_filters = []
    overflow = 0
    
    for col in unique_cols:
        available = len(col_to_filters[col])
        required = col_sample_counts[col]
        if available >= required:
            sampled = random.sample(col_to_filters[col], required)
            sampled_filters.extend(sampled)
        else:
            sampled = col_to_filters[col]
            sampled_filters.extend(sampled)
            overflow += (required - available)
    
    if overflow > 0:
        remaining_cols = [col for col in unique_cols if len(col_to_filters[col]) > col_sample_counts[col]]
        while overflow > 0 and remaining_cols:
            for col in remaining_cols.copy():
                available = len(col_to_filters[col])
                current = col_sample_counts[col]
                if available > current:
                    sampled = random.sample(
                        list(set(col_to_filters[col]) - set(sampled_filters)), 1
                    )
                    sampled_filters.extend(sampled)
                    col_sample_counts[col] += 1
                    overflow -= 1
                    if overflow == 0:
                        break
                else:
                    remaining_cols.remove(col)

    return sampled_filters[:sample_size]

#### 枚举所有 Filter 的可能排列组合，构造 WHERE 语句
- 枚举所有可能的 Filter 排列组合

In [11]:
with open("./filter_dict.json", 'r') as f:
     filter_dict = json.load(f)

filters = []
for col, conditions in filter_dict.items():
     for cond, indices in conditions.items():
          filters.append((col, cond, set(indices)))

# filters = random.sample(filters, 20)
filters = balanced_sample(filters, sample_size=20, random_seed=42)

all_combinations = []
for n in tqdm(range(1, max_filters + 1)):
     all_combinations.extend(itertools.permutations(filters, n))

with open("./all_combinations.pkl", "wb") as f:
    pickle.dump(all_combinations, f)

  0%|          | 0/5 [00:00<?, ?it/s]

#### 枚举所有 Filter 的可能排列组合，构造 WHERE 语句
- ```select_combinations_with_ratio()```：按不同 Filter 数量的比例构造 WHERE 语句
- 1 - 5 个 Filter 的比例为：2 : 3 : 3 : 1 : 1

In [12]:
def select_combinations_with_ratio(all_combinations, total=1000, ratio=[2, 3, 3, 1, 1], min_len=1, max_len=5):

    length_to_combinations = defaultdict(list)
    for combo in all_combinations:
        l = len(combo)
        if min_len <= l <= max_len:
            length_to_combinations[l].append(combo)

    total_ratio = sum(ratio)
    
    samples_per_length = {}
    remaining_total = total
    
    for i, l in enumerate(range(min_len, max_len + 1)):
        expected_samples = int(total * ratio[i] / total_ratio)
        available_samples = len(length_to_combinations[l])
        if available_samples < expected_samples:
            samples_per_length[l] = available_samples
            remaining_total -= available_samples
            ratio[i] = 0
        else:
            samples_per_length[l] = expected_samples
            remaining_total -= expected_samples
    
    if remaining_total > 0:
        for i, l in enumerate(range(min_len, max_len + 1)):
            if ratio[i] > 0:
                possible_to_add = len(length_to_combinations[l]) - samples_per_length[l]
                if possible_to_add > 0:
                    additional_samples = min(possible_to_add, remaining_total)
                    samples_per_length[l] += additional_samples
                    remaining_total -= additional_samples
                if remaining_total == 0:
                    break

    if remaining_total != 0:
        raise ValueError(f"The number of combinations cannot meet the requirements, with {remaining_total} combinations remaining.")
    
    selected = []
    for l in range(min_len, max_len + 1):
        available = length_to_combinations[l]
        required = samples_per_length[l]
        selected += random.sample(available, required)
    
    random.shuffle(selected)
    return selected

#### 枚举所有 Filter 的可能排列组合，构造 WHERE 语句
- 析取 + 合取 + 析取与合取混合
- 去掉等效表达式
- 每个查询的结果表至少包含 ```min_rows``` 行 

In [13]:
with open("./all_combinations.pkl", "rb") as f:
     all_combinations = pickle.load(f)

def normalize_expression(expression):

    def sort_expression(expr):
        if " AND " in expr:
            parts = sorted(expr.strip("()").split(" AND "))
            return f"({' AND '.join(parts)})"
        elif " OR " in expr:
            parts = sorted(expr.strip("()").split(" OR "))
            return f"({' OR '.join(parts)})"
        return expr
    
    if " AND " in expression or " OR " in expression:
        return sort_expression(expression)
    return expression

valid_where = []
seen_expressions = set()

# sampled_combinations = random.sample(all_combinations, 1000)
sampled_combinations = select_combinations_with_ratio(all_combinations, total=1000, ratio=[2, 3, 3, 1, 1], min_len=1, max_len=max_filters)

for combo in tqdm(sampled_combinations):
    condition_sets = [set(item[2]) for item in combo]
    predicates = [f"{col}{cond}" for col, cond, _ in combo]

    for op in itertools.product(["and", "or"], repeat=len(condition_sets) - 1):
        # 使用更直接的逻辑计算
        result_set = condition_sets[0].copy()  # 从第一个集合开始
        result_expr = predicates[0]
        
        # 依次应用每个操作符
        for i, logic in enumerate(op):
            next_set = condition_sets[i + 1]
            next_pred = predicates[i + 1]
            
            if logic == "and":
                result_set = result_set & next_set
                result_expr = f"({result_expr} AND {next_pred})"
            elif logic == "or":
                result_set = result_set | next_set
                result_expr = f"({result_expr} OR {next_pred})"
        
        # 检查结果
        if len(result_set) >= min_rows:
            normalized_expression = normalize_expression(result_expr)
            if normalized_expression not in seen_expressions:
                seen_expressions.add(normalized_expression)
                
                combo_list = [[c[0], c[1]] for c in combo]
                query_dict = {
                    "WHERE Indices": list(result_set),
                    "WHERE Total Rows": len(result_set),
                    "Combination": combo_list,
                    "Operators": list(op),
                    "WHERE": result_expr
                }
                valid_where.append(query_dict)
                # if len(valid_where) > 1000:
                #         with open("/data/sunzhaoze/benchmark/valid_WHERE.json", 'a') as f:
                #             json.dump(valid_where, f, ensure_ascii=False, indent=4, separators=(',', ': '))
                #         valid_where.clear()

    if valid_where:
        with open(valid_where_output_dir, 'w') as f:
            json.dump(valid_where, f, ensure_ascii=False)

  0%|          | 0/1000 [00:00<?, ?it/s]

#### 定义表头 SCHEMA
- ```create_schema()```：SCHEMA 定义函数，数据共三种类型：VARCHAR(255) | DATE | FLOAT

In [14]:
def create_schema(attr_desc_dict, query_attr_list):
     schema = {}
     difference = set(attr_desc_dict.keys()) - set(query_attr_list)
     redundant_attr_list = random.sample(list(difference), random.randint(0, 4))
     schema_attr_list = query_attr_list + redundant_attr_list
     for key in schema_attr_list:
          if key in non_numerical_attr:
               schema[key] = ["VARCHAR(255)", attr_desc_dict[key]]
          elif key in formatted_attr:
               schema[key] = ["DATE", attr_desc_dict[key]]
          elif key in non_formatted_attr:
               schema[key] = ["FLOAT", attr_desc_dict[key]]
     return schema

#### json 自定义格式化输出

In [15]:
def custom_json_dump(data, filename):
    json_str = json.dumps(data, ensure_ascii=False)
    json_str = json_str.replace(", {", ",\n{").replace("{\"", "{\n\t\"").replace("]},\n", "]\n},\n").replace("e},\n", "e\n},\n").replace("}},\n", "}\n},\n")
    json_str = re.sub(r',\s*"([^"]+)":', r', \n\t"\1":', json_str)
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(json_str)

#### 构建 SELECT | FROM

In [16]:
def create_sf_queries():
    valid_sf = []
    
    for i in tqdm(range(sample_sf)):
        # 随机选择1到max_select个列
        selected_columns = random.sample(list(attr_value_dict.keys()), random.randint(1, max_select))
        
        # 创建查询字典
        query_dict = {
            "Type": "SF",
            "SELECT": selected_columns,
            "WHERE Indices": list(range(len(df))),  # 包含所有行
            "WHERE Total Rows": len(df),            # 总行数
            "Combination": [],                      # 没有Filter组合
            "Operators": [],                        # 没有操作符
            "WHERE": "None"                         # 没有WHERE条件
        }
        
        # 创建SCHEMA
        query_attr_list = selected_columns.copy()
        query_dict["SCHEMA"] = create_schema(attr_desc_dict, query_attr_list)
        
        valid_sf.append(query_dict)
    
    return valid_sf
valid_sf = create_sf_queries()
custom_json_dump(valid_sf, "./SELECT_FROM.json")

  0%|          | 0/10 [00:00<?, ?it/s]

#### 构建 SELECT | FROM | WHERE

In [17]:
with open(valid_where_output_dir, 'r') as f:
     valid_sfw = json.load(f)

for i in tqdm(range(0, len(valid_sfw))):
    selected_columns = random.sample(list(attr_value_dict.keys()), random.randint(1, max_select))
    valid_sfw[i]["Type"] = "SFW"
    valid_sfw[i]["SELECT"] = selected_columns
    query_attr_list = [i for i in selected_columns]
    for k in valid_sfw[i]["Combination"]:
         if k[0] not in query_attr_list:
             query_attr_list += [k[0]]
    valid_sfw[i]["SCHEMA"] = create_schema(attr_desc_dict, query_attr_list)

sampled_valid_sfw = random.sample(valid_sfw, sample_sfw)

# with open("./SELECT_FROM_WHERE.json", 'w') as f:
#      json.dump(sampled_valid_sfw, f, ensure_ascii=False)
custom_json_dump(sampled_valid_sfw, "./SELECT_FROM_WHERE.json")

  0%|          | 0/3265 [00:00<?, ?it/s]

#### 构建 SELECT | FROM | WHERE | GROUP BY
- 返回多个表，每个组一个表

In [18]:
with open(valid_where_output_dir, 'r') as f:
     valid_where = json.load(f)

valid_sfwg = []
for i in tqdm(range(0, len(valid_where))):
     # while True:
     #      selected_columns = random.sample(list(attr_value_dict.keys()), random.randint(1, max_groupby))
     #      if not (set(selected_columns) & set(formatted_attr)):
     #           break
     selected_columns = random.sample(list(attr_value_dict.keys()), random.randint(1, max_select))
     valid_groupby_attrs = [attr for attr in category_attr if attr not in multi_value_attributes]
     groupby_columns = random.sample(valid_groupby_attrs, 1)
     row_indices = valid_where[i].get("WHERE Indices", [])
     filtered_df = df.loc[row_indices]
     grouped = filtered_df.groupby(groupby_columns)
     remaining_rows = grouped.size().reset_index()
     if len(remaining_rows) > (min_rows // 2):
          valid_where[i]["Type"] = "SFWG"
          valid_where[i]["SELECT"] = selected_columns + groupby_columns if groupby_columns not in selected_columns else selected_columns
          valid_where[i]["GROUP BY Total Rows (Groups)"] = len(remaining_rows)
          valid_where[i]["GROUP BY"] = groupby_columns
          query_attr_list = [i for i in groupby_columns]
          for k in valid_where[i]["Combination"]:
               if k[0] not in query_attr_list:
                    query_attr_list += [k[0]]
          valid_where[i]["SCHEMA"] = create_schema(attr_desc_dict, query_attr_list)
          valid_sfwg.append(valid_where[i])

sampled_valid_sfwg = random.sample(valid_sfwg, sample_sfwg)

# with open("./SELECT_FROM_WHERE_GROUPBY.json", 'w') as f:
#      json.dump(sampled_valid_sfwg, f, ensure_ascii=False)
custom_json_dump(sampled_valid_sfwg, "./SELECT_FROM_WHERE_GROUPBY.json")

  0%|          | 0/3265 [00:00<?, ?it/s]

#### 构建 SELECT | FROM | WHERE | AGGREGATION
- 支持的聚合函数：SUM | MAX | MIN | AVG | COUNT
- TODO：加入语义聚合
- 只在一个属性上进行聚合

In [19]:
with open(valid_where_output_dir, 'r') as f:
     valid_where = json.load(f)

valid_sfwa = []
aggregation_functions = ['COUNT', 'MAX', 'MIN', 'AVG', 'SUM']
max_COUNT = len(valid_where) // 4

for i in tqdm(range(0, len(valid_where))):
     numerical = False
     while True:
          if max_COUNT > 0:
               selected_aggregation_column = random.sample(list(attr_value_dict.keys()), 1)
               if selected_aggregation_column not in formatted_attr:
                    break
          else:
               selected_aggregation_column = random.sample(list(non_formatted_attr), 1)
               break
          
     if selected_aggregation_column[0] in non_formatted_attr:
          function = random.choice(aggregation_functions[1:])
          numerical = True
     else:
          max_COUNT -= 1
          function = 'COUNT'
     if function == "COUNT" and random.uniform(0, 1) > 0.5:
          selected_aggregation_column = ["*"]
     valid_where[i]["Type"] = "SFWA"
     valid_where[i]["SELECT"] = [f"{function}({selected_aggregation_column[0]})"]
     query_attr_list = [i for i in selected_aggregation_column]
     for k in valid_where[i]["Combination"]:
          if k[0] not in query_attr_list:
               query_attr_list += [k[0]]
     valid_where[i]["AGGREGATION"] = selected_aggregation_column
     valid_where[i]["AGGREGATION Function"] = function
     valid_where[i]["Numerical"] = numerical
     valid_where[i]["SCHEMA"] = create_schema(attr_desc_dict, query_attr_list)
     valid_sfwa.append(valid_where[i])

sampled_valid_sfwa = random.sample(valid_sfwa, sample_sfwa)

# with open("./SELECT_FROM_WHERE_AGGREGATION.json", 'w') as f:
#      json.dump(sampled_valid_sfwa, f, ensure_ascii=False)
custom_json_dump(sampled_valid_sfwa, "./SELECT_FROM_WHERE_AGGREGATION.json")

  0%|          | 0/3265 [00:00<?, ?it/s]

#### 构建 SELECT | FROM | GROUP BY | AGGREGATION

In [20]:
valid_sfag = []
aggregation_functions = ['COUNT', 'MAX', 'MIN', 'MEAN', 'SUM']
sample_sfag = 20  # 定义生成数量
max_COUNT = sample_sfag // 4

for i in tqdm(range(sample_sfag)):
    # 选择分组列（排除多值属性）
    valid_groupby_attrs = [attr for attr in category_attr if attr not in multi_value_attributes]
    selected_groupby_columns = random.sample(valid_groupby_attrs, 1)
    
    # 选择聚合函数和列
    if max_COUNT > 0:
        selected_aggregation_column = random.sample(list(set(attr_value_dict.keys())-set(selected_groupby_columns)), 1)
        if selected_aggregation_column[0] not in formatted_attr:
            function = random.choice(aggregation_functions[1:])
            numerical = True
        else:
            max_COUNT -= 1
            function = 'COUNT'
            numerical = False
    else:
        selected_aggregation_column = random.sample(list(set(non_formatted_attr)-set(selected_groupby_columns)), 1)
        function = random.choice(aggregation_functions[1:])
        numerical = True
    
    if function == "COUNT" and random.uniform(0, 1) > 0.5:
        selected_aggregation_column = ["*"]
    
    query_dict = {
        "Type": "SFAG",
        "GROUP BY": selected_groupby_columns,
        "SELECT": selected_groupby_columns + [f"{function}({selected_aggregation_column[0]})"],
        "WHERE Indices": list(range(len(df))),  # 包含所有行
        "WHERE Total Rows": len(df),            # 总行数
        "Combination": [],                      # 没有Filter组合
        "Operators": [],                        # 没有操作符
        "WHERE": "None",                        # 没有WHERE条件
        "AGGREGATION": selected_aggregation_column,
        "AGGREGATION Function": "AVG" if function == "MEAN" else function,
        "Numerical": numerical
    }
    
    # 创建SCHEMA
    query_attr_list = selected_aggregation_column + selected_groupby_columns
    query_dict["SCHEMA"] = create_schema(attr_desc_dict, query_attr_list)
    
    valid_sfag.append(query_dict)

custom_json_dump(valid_sfag, "./SELECT_FROM_AGGREGATION_GROUPBY.json")

  0%|          | 0/20 [00:00<?, ?it/s]

#### 构建 SELECT | FROM | WHERE | GROUP BY | AGGREGATION

In [21]:
def convert_column_to_float(df, column_name):
    df[column_name] = df[column_name].apply(pd.to_numeric, errors='coerce')
    return df

for i in non_formatted_attr:
    convert_column_to_float(df, i)

In [22]:
with open(valid_where_output_dir, 'r') as f:
     valid_where = json.load(f)

valid_sfwga = []
aggregation_functions = ['COUNT', 'MAX', 'MIN', 'MEAN', 'SUM']
max_COUNT = sample_sfwga // 4

for i in tqdm(range(0, len(valid_where))):
     row_indices = valid_where[i].get("WHERE Indices", [])
     filtered_df = df.loc[row_indices]
     # while True:
     #      selected_groupby_columns = random.sample(list(attr_value_dict.keys()), random.randint(1, max_groupby))
     #      if not set(selected_groupby_columns) & set(formatted_attr):
     #           break
     valid_groupby_attrs = [attr for attr in category_attr if attr not in multi_value_attributes]
     selected_groupby_columns = random.sample(valid_groupby_attrs, 1)
     grouped = filtered_df.groupby(selected_groupby_columns)
     remaining_rows = grouped.size().reset_index(name='Group Size')

     if len(remaining_rows) > (min_rows // 2):
          while True:
               if max_COUNT > 0:
                    selected_aggregation_column = random.sample(list(set(attr_value_dict.keys())-set(selected_groupby_columns)), 1)
                    if selected_aggregation_column not in formatted_attr:
                         break
               else:
                    selected_aggregation_column = random.sample(list(set(attr_value_dict.keys())-set(selected_groupby_columns)), 1)
                    break
               
          if selected_aggregation_column[0] in non_formatted_attr:
               function = random.choice(aggregation_functions[1:])
               numerical = True
          else:
               max_COUNT -= 1
               function = 'COUNT'
               numerical = False
          
          if function == "COUNT" and random.uniform(0, 1) > 0.5:
               selected_aggregation_column = ["*"]
          agg_result = grouped[selected_aggregation_column].agg(function.lower()) if numerical else grouped.size()
          valid_where[i]["Type"] = "SFWGA"
          valid_where[i]["GROUP BY"] = selected_groupby_columns
          valid_where[i]["GROUP BY Total Rows"] = len(remaining_rows)
          valid_where[i]["SELECT"] = selected_groupby_columns + [f"{function}({selected_aggregation_column[0]})"]

          valid_where[i]["AGGREGATION"] = selected_aggregation_column
          if function == "MEAN":
               valid_where[i]["AGGREGATION Function"] = "AVG"
          else:
               valid_where[i]["AGGREGATION Function"] = function
          valid_where[i]["Numerical"] = numerical
          query_attr_list = [i for i in selected_aggregation_column]
          for k in valid_where[i]["Combination"]:
               if k[0] not in query_attr_list:
                    query_attr_list += [k[0]]
          for k in selected_groupby_columns:
               if k not in query_attr_list:
                    query_attr_list += [k]
          valid_where[i]["SCHEMA"] = create_schema(attr_desc_dict, query_attr_list)
          valid_sfwga.append(valid_where[i])

sampled_valid_sfwga = random.sample(valid_sfwga, sample_sfwga)

# with open("./SELECT_FROM_WHERE_GROUPBY_AGGREGATION.json", 'w') as f:
#      json.dump(sampled_valid_sfwga, f, ensure_ascii=False)
custom_json_dump(sampled_valid_sfwga, "./SELECT_FROM_WHERE_GROUPBY_AGGREGATION.json")

  0%|          | 0/3265 [00:00<?, ?it/s]

#### 构建 SELECT | FROM | WHERE | TOP-K

In [23]:
with open(valid_where_output_dir, 'r') as f:
    valid_where = json.load(f)

valid_sfwt = []
order_options = ['ASC', 'DESC']

for i in tqdm(range(0, len(valid_where))):
    row_indices = valid_where[i].get("WHERE Indices", [])
    filtered_df = df.loc[row_indices]
    while True:
        selected_columns = random.sample(list(attr_value_dict.keys()), random.randint(1, max_select))
        if set(selected_columns) & set(non_formatted_attr):
            order_column = random.choice(list(set(selected_columns) & set(non_formatted_attr)))
            break

    order_type = random.choice(order_options)

    limit_value = min(random.sample(limit_list, 1)[0], len(filtered_df) // 2)
    valid_where[i]["Type"] = "SFWT"
    valid_where[i]["SELECT"] = selected_columns
    valid_where[i]["LIMIT"] = limit_value
    valid_where[i]["ORDER BY"] = [order_column, order_type]
    query_attr_list = [i for i in selected_columns]
    for k in valid_where[i]["Combination"]:
         if k[0] not in query_attr_list:
             query_attr_list += [k[0]]
    valid_where[i]["SCHEMA"] = create_schema(attr_desc_dict, query_attr_list)

    valid_sfwt.append(valid_where[i])

sampled_valid_sfwt = random.sample(valid_sfwt, sample_sfwt)

# with open("./SELECT_FROM_WHERE_TOPK.json", 'w') as f:
#     json.dump(sampled_valid_sfwt, f, ensure_ascii=False)
custom_json_dump(sampled_valid_sfwt, "./SELECT_FROM_WHERE_TOPK.json")

  0%|          | 0/3265 [00:00<?, ?it/s]

#### 构建 SELECT | FROM | WHERE | GROUP BY | AGGREGATION | TOP-K

In [24]:
with open(valid_where_output_dir, 'r') as f:
     valid_where = json.load(f)

valid_sfwgat = []
aggregation_functions = ['COUNT', 'MAX', 'MIN', 'MEAN', 'SUM']
order_options = ['ASC', 'DESC']
max_COUNT = sample_sfwgat // 4

for i in tqdm(range(0, len(valid_where))):
     row_indices = valid_where[i].get("WHERE Indices", [])
     filtered_df = df.loc[row_indices]

     # while True:
     #      selected_groupby_columns = random.sample(list(attr_value_dict.keys()), random.randint(1, max_groupby))
     #      if not set(selected_groupby_columns) & set(formatted_attr):
     #           break
     valid_groupby_attrs = [attr for attr in category_attr if attr not in multi_value_attributes]
     selected_groupby_columns = random.sample(valid_groupby_attrs, 1)
     grouped = filtered_df.groupby(selected_groupby_columns)
     remaining_rows = grouped.size().reset_index(name='Group Size')

     if len(remaining_rows) > (min_rows // 2):
          while True:
               if max_COUNT > 0:
                    selected_aggregation_column = random.sample(list(set(attr_value_dict.keys())-set(selected_groupby_columns)), 1)
                    if selected_aggregation_column not in formatted_attr:
                         break
               else:
                    selected_aggregation_column = random.sample(list(set(attr_value_dict.keys())-set(selected_groupby_columns)), 1)
                    break
          if selected_aggregation_column[0] in non_formatted_attr:
               function = random.choice(aggregation_functions[1:])
               numerical = True
          else:
               max_COUNT -= 1
               function = 'COUNT'
               numerical = False

          if function == "COUNT" and random.uniform(0, 1) > 0.5:
               selected_aggregation_column = ["*"]

          if selected_aggregation_column == ["*"]:
               agg_result = grouped.transform(lambda x: x.notna().sum())
          else:
               agg_result = grouped[selected_aggregation_column[0]].transform(lambda x: x.notna().sum())

          # agg_result = grouped[selected_aggregation_column].agg(function.lower()) if numerical else grouped.size()
          select_set = set(selected_groupby_columns).union(set(selected_aggregation_column))
          order_set = select_set & set(non_formatted_attr)
          if not order_set:
               continue
          order_column = random.choice(list(order_set))
          order_type = random.choice(order_options)
          limit_value = min(random.sample(limit_list, 1)[0], len(agg_result // 2))
          valid_where[i]["Type"] = "SFWGAT"
          valid_where[i]["GROUP BY"] = selected_groupby_columns
          valid_where[i]["GROUP BY Total Rows"] = len(remaining_rows)
          valid_where[i]["SELECT"] = selected_groupby_columns + [f"{function}({selected_aggregation_column[0]})"]
          valid_where[i]["AGGREGATION"] = selected_aggregation_column
          if function == "MEAN":
               valid_where[i]["AGGREGATION Function"] = "AVG"
          else:
               valid_where[i]["AGGREGATION Function"] = function
          valid_where[i]["Numerical"] = numerical
          valid_where[i]["LIMIT"] = limit_value
          valid_where[i]["ORDER BY"] = [order_column, order_type]
          query_attr_list = [i for i in selected_aggregation_column]
          for k in valid_where[i]["Combination"]:
               if k[0] not in query_attr_list:
                    query_attr_list += [k[0]]
          for k in selected_groupby_columns:
               if k not in query_attr_list:
                    query_attr_list += [k]
          valid_where[i]["SCHEMA"] = create_schema(attr_desc_dict, query_attr_list)
                    
          valid_sfwgat.append(valid_where[i])

sampled_valid_sfwgat = random.sample(valid_sfwgat, sample_sfwgat)

# with open("./SELECT_FROM_WHERE_GROUPBY_AGGREGATION_TOPK.json", 'w') as f:
#     json.dump(sampled_valid_sfwgat, f, ensure_ascii=False)
custom_json_dump(sampled_valid_sfwgat, "./SELECT_FROM_WHERE_GROUPBY_AGGREGATION_TOPK.json")

  0%|          | 0/3265 [00:00<?, ?it/s]

In [None]:
import json

#### 根据构造的查询结构生成对应的SQL语句
def generate_sql_query(query_dict, table_name="LCR"):
    """
    根据查询字典生成对应的SQL语句（支持SF类型）
    """
    query_type = query_dict.get("Type", "")
    
    # 构建SELECT子句
    select_clause = "SELECT " + ", ".join(query_dict["SELECT"])
    
    # 构建FROM子句
    from_clause = f"FROM {table_name}"
    
    # 构建WHERE子句（SF类型没有WHERE）
    where_clause = ""
    if "WHERE" in query_dict and query_dict["WHERE"] != "None":
        where_clause = f"WHERE {query_dict['WHERE']}"
    
    # 构建GROUP BY子句
    group_by_clause = ""
    if "GROUP BY" in query_dict:
        group_by_clause = f"GROUP BY {', '.join(query_dict['GROUP BY'])}"
    
    # 构建ORDER BY子句
    order_by_clause = ""
    if "ORDER BY" in query_dict:
        order_col, order_type = query_dict["ORDER BY"]
        order_by_clause = f"ORDER BY {order_col} {order_type}"
    
    # 构建LIMIT子句
    limit_clause = ""
    if "LIMIT" in query_dict:
        limit_clause = f"LIMIT {query_dict['LIMIT']}"
    
    # 组合完整SQL
    sql_parts = [select_clause, from_clause]
    if where_clause:
        sql_parts.append(where_clause)
    if group_by_clause:
        sql_parts.append(group_by_clause)
    if order_by_clause:
        sql_parts.append(order_by_clause)
    if limit_clause:
        sql_parts.append(limit_clause)
    
    return "\n".join(sql_parts) + ";"

def generate_schema_sql(schema_dict, table_name="LCR"):
    """
    根据SCHEMA字典生成建表SQL语句
    """
    create_table = f"CREATE TABLE {table_name} (\n"
    columns = []
    
    for col_name, (data_type, description) in schema_dict.items():
        comment = f" COMMENT '{description}'" if description else ""
        columns.append(f"    {col_name} {data_type}{comment}")
    
    create_table += ",\n".join(columns)
    create_table += "\n);"
    
    return create_table

def save_queries_with_sql(input_files, output_dir="./sql_queries/"):
    """
    读取所有生成的查询文件，为每个查询生成对应的SQL语句并保存（支持SF类型）
    修正：使用正确的换行符格式
    """
    import os
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    for file_path in input_files:
        if not os.path.exists(file_path):
            print(f"文件不存在: {file_path}")
            continue
            
        with open(file_path, 'r', encoding='utf-8') as f:
            queries = json.load(f)
        
        # 为每个查询生成SQL
        for i, query in enumerate(queries):
            # 生成建表SQL
            schema_sql = generate_schema_sql(query["SCHEMA"])
            
            # 生成查询SQL
            query_sql = generate_sql_query(query)
            
            # 计算SELECT数量
            select_count = len(query.get("SELECT", []))
            
            # 计算Filter数量 - 使用Combination字段获取最准确的数量
            filter_count = len(query.get("Combination", []))
            
            # 组合完整SQL - 修正：使用正确的换行符 \n 而不是 \\n
            complete_sql = f"-- Query {i+1} ({query['Type']})\n"
            complete_sql += f"-- Total Rows: {query.get('WHERE Total Rows', 'N/A')}\n"
            complete_sql += f"-- SELECT: {select_count}\n"
            complete_sql += f"-- FILTER: {filter_count}\n\n"
            complete_sql += schema_sql + "\n\n"
            complete_sql += query_sql + "\n"
            complete_sql += "-" * 50 + "\n\n"
            
            query["SQL"] = {
                "schema": schema_sql,
                "query": query_sql,
                "complete": complete_sql
            }
        
        # 保存包含SQL的查询文件
        output_file = output_dir + os.path.basename(file_path).replace('.json', '_with_sql.json')
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(queries, f, ensure_ascii=False, indent=2)
        
        # 单独保存纯SQL文件
        sql_file = output_dir + os.path.basename(file_path).replace('.json', '.sql')
        with open(sql_file, 'w', encoding='utf-8') as f:
            for query in queries:
                f.write(query["SQL"]["complete"])
        
        print(f"已生成SQL文件: {sql_file}")

# 使用示例：为所有查询类型生成SQL
query_files = [
    "./SELECT_FROM.json",
    "./SELECT_FROM_WHERE.json",
    "./SELECT_FROM_WHERE_TOPK.json", 
    "./SELECT_FROM_WHERE_GROUPBY.json",
    "./SELECT_FROM_WHERE_AGGREGATION.json",
    "./SELECT_FROM_AGGREGATION_GROUPBY.json",
    "./SELECT_FROM_WHERE_GROUPBY_AGGREGATION.json",
    "./SELECT_FROM_WHERE_GROUPBY_AGGREGATION_TOPK.json"
]

# 生成所有SQL查询
save_queries_with_sql(query_files)

已生成SQL文件: ./sql_queries/SELECT_FROM.sql
已生成SQL文件: ./sql_queries/SELECT_FROM_WHERE.sql
已生成SQL文件: ./sql_queries/SELECT_FROM_WHERE_TOPK.sql
已生成SQL文件: ./sql_queries/SELECT_FROM_WHERE_GROUPBY.sql
已生成SQL文件: ./sql_queries/SELECT_FROM_WHERE_AGGREGATION.sql
已生成SQL文件: ./sql_queries/SELECT_FROM_AGGREGATION_GROUPBY.sql
已生成SQL文件: ./sql_queries/SELECT_FROM_WHERE_GROUPBY_AGGREGATION.sql
已生成SQL文件: ./sql_queries/SELECT_FROM_WHERE_GROUPBY_AGGREGATION_TOPK.sql


#### Filter固定

In [27]:
#### SQL文件生成器 - 按类别和类型分文件夹
import json
import random
import os
from tqdm import tqdm

# 法律数据集的语义模板定义 - 按Filter数量2:3:3:1:1分布
TEMPLATES = [
    # 1个Filter (2个类别)
    {"name": "judge_focus", "description": "特定法官案件分析", "filters": ["judge_name"]},
    {"name": "case_type_focus", "description": "特定案件类型研究", "filters": ["case_type"]},
    
    # 2个Filter (3个类别) 
    {"name": "judge_verdict", "description": "法官与判决结果关联", "filters": ["judge_name", "verdict"]},
    {"name": "location_year", "description": "审理地点与年份分析", "filters": ["hearing_location", "hearing_year"]},
    {"name": "type_outcome", "description": "案件类型与结果关系", "filters": ["case_type", "verdict"]},
    
    # 3个Filter (3个类别)
    {"name": "court_analysis", "description": "法庭审理综合分析", "filters": ["judge_name", "hearing_location", "case_type"]},
    {"name": "temporal_pattern", "description": "时间模式与结果研究", "filters": ["hearing_year", "judgment_year", "verdict"]},
    {"name": "party_analysis", "description": "当事方特征分析", "filters": ["plaintiff", "defendant", "case_type"]},
    
    # 4个Filter (1个类别)
    {"name": "complex_litigation", "description": "复杂诉讼案件深度研究", "filters": ["case_type", "judge_name", "hearing_location", "legal_fees"]},
    
    # 5个Filter (1个类别)
    {"name": "comprehensive_case_study", "description": "案件全方位分析", "filters": ["judge_name", "case_type", "verdict", "hearing_location", "legal_fees"]}
]

QUERY_TYPES = ["SFW", "SFWT", "SFWG", "SFWA", "SFAG", "SFWGA", "SFWGAT"]

class SQLGenerator:
    def __init__(self, min_result_rows=5):
        self.df = globals()['df']
        self.attr_dict = globals()['attr_value_dict']
        self.numerical_attr = globals()['numerical_attr']
        self.non_numerical_attr = globals()['non_numerical_attr']
        self.category_attr = globals()['category_attr']
        self.formatted_attr = globals()['formatted_attr']
        self.attr_desc_dict = globals()['attr_desc_dict']
        self.min_rows = min_result_rows
        
        # 分析数据分布用于放宽策略
        self.stats = self._analyze_data()
        
    def _analyze_data(self):
        """分析数据分布"""
        stats = {}
        for attr in self.attr_dict.keys():
            if attr in self.non_numerical_attr:
                try:
                    value_counts = self.df[attr].value_counts()
                    stats[attr] = {
                        "type": "categorical", 
                        "top_values": value_counts.head(5).to_dict()
                    }
                except:
                    stats[attr] = {"type": "categorical", "top_values": {}}
        return stats
        
    def get_filter_value(self, attr):
        """从attr_value_dict获取值和条件"""
        values = self.attr_dict.get(attr, [])
        if not values:
            return None, None
            
        if attr in self.numerical_attr:
            val = random.choice(values)
            op = random.choice(["==", ">", ">=", "<", "<="])
            return val, f"{attr} {op} {val}"
        else:
            val = random.choice(values)
            return val, f"{attr} == '{val}'"
    
    def get_filter_value(self, attr, relaxation_level=0):
        """从attr_value_dict获取值和条件，支持放宽策略"""
        values = self.attr_dict.get(attr, [])
        if not values:
            return None, None
        
        if attr in self.numerical_attr:
            val = random.choice(values)
            
            # 根据放宽级别调整操作符
            if relaxation_level == 0:
                # 正常策略：所有操作符
                op = random.choice(["==", ">", ">=", "<", "<="])
            elif relaxation_level == 1:
                # 第一次放宽：偏向宽松操作符
                op = random.choice([">=", "<=", "==", ">", "<"])
            elif relaxation_level == 2:
                # 第二次放宽：更多宽松操作符
                op = random.choice([">=", "<=", ">=", "<=", "=="])
            else:
                # 最大放宽：只用最宽松的操作符
                op = random.choice([">=", "<="])
                
            # 对于特定数值属性的智能调整
            if relaxation_level >= 1:
                if attr == "Age":
                    # 年龄放宽：选择较小的值用于>=，较大的值用于<=
                    sorted_vals = sorted(values)
                    if op in [">=", ">"]:
                        val = random.choice(sorted_vals[:len(sorted_vals)//2])
                    elif op in ["<=", "<"]:
                        val = random.choice(sorted_vals[len(sorted_vals)//2:])
                elif attr == "Awards":
                    # 奖项放宽：降低阈值
                    if op in [">=", ">"]:
                        min_awards = min(values)
                        val = min_awards if relaxation_level >= 2 else random.choice([v for v in values if v <= min_awards + 2])
            
            return val, f"{attr} {op} {val}"
        else:
            val = random.choice(values)
            return val, f"{attr} == '{val}'"
    
    def apply_filters_with_relaxation(self, filters_config, max_relaxation=3):
        """应用过滤条件，支持自动放宽策略"""
        
        for relaxation_level in range(max_relaxation + 1):
            result_indices, applied_conditions = self._try_apply_filters(filters_config, relaxation_level)
            
            # 如果有结果，返回
            if len(result_indices) >= 2:  # 最少2行结果
                return result_indices, applied_conditions, relaxation_level
        
        # 所有级别都失败，返回失败标记
        return None, None, "FAILED_ALL_RELAXATION"
    
    def get_filter_value(self, attr, relaxation_level=0):
        """从attr_value_dict获取值和条件，支持放宽策略"""
        values = self.attr_dict.get(attr, [])
        if not values:
            return None, None
        
        if attr in self.numerical_attr:
            val = random.choice(values)
            
            # 根据放宽级别调整操作符，让条件更宽松
            if relaxation_level == 0:
                # 正常策略：所有操作符
                op = random.choice(["==", ">", ">=", "<", "<="])
            elif relaxation_level == 1:
                # 第一次放宽：偏向宽松操作符，避免等于
                op = random.choice([">=", "<=", ">", "<", ">=", "<="])
            elif relaxation_level == 2:
                # 第二次放宽：只用最宽松的操作符
                op = random.choice([">=", "<="])
            else:
                # 最大放宽：选择能包含更多数据的阈值
                op = random.choice([">=", "<="])
                
            # 对于特定数值属性的智能调整
            if relaxation_level >= 1:
                if attr == "Age":
                    # 年龄放宽：选择更极端的值来包含更多人
                    sorted_vals = sorted(values)
                    if op in [">=", ">"]:
                        # 选择较小的年龄值，让更多人满足条件
                        val = random.choice(sorted_vals[:len(sorted_vals)//3])
                    elif op in ["<=", "<"]:
                        # 选择较大的年龄值，让更多人满足条件
                        val = random.choice(sorted_vals[2*len(sorted_vals)//3:])
                elif attr == "Awards":
                    # 奖项放宽：降低门槛
                    if op in [">=", ">"]:
                        min_awards = min([v for v in values if isinstance(v, (int, float))])
                        val = min_awards if relaxation_level >= 2 else random.choice([v for v in values if v <= min_awards + 1])
            
            return val, f"{attr} {op} {val}"
        else:
            # 非数值属性：选择更常见的值
            if relaxation_level >= 1 and attr in self.stats and self.stats[attr]["type"] == "categorical":
                # 选择出现频率更高的值
                top_values = list(self.stats[attr]["top_values"].keys())[:3]  # 前3个最常见的值
                common_values = [v for v in top_values if v in values]
                if common_values:
                    val = random.choice(common_values)
                else:
                    val = random.choice(values)
            else:
                val = random.choice(values)
            return val, f"{attr} == '{val}'"
    
    def _try_apply_filters(self, filters_config, relaxation_level):
        """尝试应用过滤条件，严格保持Filter数量"""
        result_indices = set(range(len(self.df)))
        applied_conditions = []
        
        # 必须使用所有Filter，数量必须与模板定义完全一致
        for attr in filters_config:
            val, condition = self.get_filter_value(attr, relaxation_level)
            if not val or not condition:
                # 如果某个Filter无法生成条件，整个查询失败
                return [], []
                
            try:
                # 测试过滤效果
                if condition.endswith("'"):  # 字符串条件
                    mask = self.df[attr].apply(lambda x: 
                        str(x).strip().lower() == str(val).strip().lower() or
                        (isinstance(x, str) and '||' in x and str(val) in x.split('||')))
                elif ">=" in condition:
                    mask = self.df[attr] >= val
                elif ">" in condition:
                    mask = self.df[attr] > val
                elif "<=" in condition:
                    mask = self.df[attr] <= val
                elif "<" in condition:
                    mask = self.df[attr] < val
                elif "==" in condition:
                    mask = self.df[attr] == val
                else:
                    # 如果条件格式错误，整个查询失败
                    return [], []
                
                new_indices = set(self.df[mask].index)
                result_indices &= new_indices
                applied_conditions.append(condition)
                
            except Exception:
                # 如果任何Filter应用失败，整个查询失败
                return [], []
        
        # 验证Filter数量完全匹配
        if len(applied_conditions) != len(filters_config):
            return [], []
        
        return list(result_indices), applied_conditions
    
    def create_schema_sql(self, attrs, table_name="Wikiart"):
        """生成建表SQL"""
        schema_parts = []
        for attr in set(attrs):
            if '(' in attr:  # 跳过聚合函数
                continue
            if attr in self.numerical_attr:
                schema_parts.append(f"    {attr} FLOAT")
            elif attr in self.formatted_attr:
                schema_parts.append(f"    {attr} DATE")
            else:
                schema_parts.append(f"    {attr} VARCHAR(255)")
        
        return f"CREATE TABLE {table_name} (\n" + ",\n".join(schema_parts) + "\n);"
    
    def generate_query_sql(self, template, qtype, query_id):
        """生成单个查询的SQL，带自动放宽机制"""
        # 使用自适应放宽策略
        if qtype == "SFAG":
            indices = list(range(len(self.df)))
            conditions = []
            relaxation_used = 0
        else:
            result = self.apply_filters_with_relaxation(template["filters"])
            
            if result[2] == "FAILED_ALL_RELAXATION":
                return None
            
            indices, conditions, relaxation_used = result
            
            if len(indices) < 2 or not conditions:
                return None
                
        result = self.apply_filters_with_relaxation(template["filters"])
        
        # 检查是否完全失败
        if result[2] == "FAILED_ALL_RELAXATION":
            # 生成失败标注的SQL注释
            header = f"-- Query {query_id} - {qtype} [GENERATION FAILED]\n"
            header += f"-- Template: {template['name']}\n"
            header += f"-- Description: {template['description']}\n"
            header += f"-- Required Filters: {len(template['filters'])} (STRICTLY MAINTAINED)\n"
            header += f"-- Filter List: {', '.join(template['filters'])}\n"
            header += f"-- Status: All {len(template['filters'])} filters must be applied but combination yields no results\n"
            header += f"-- Reason: Filter combination too restrictive for current dataset even with maximum value relaxation\n"
            header += f"-- Semantic Integrity: PRESERVED - No filter reduction allowed\n\n"
            
            # 创建一个示例SQL结构用于说明
            example_sql = f"-- Required SQL structure (unfulfillable):\n"
            example_sql += f"-- CREATE TABLE Wikiart (...columns for {', '.join(template['filters'])}...);\n"
            example_sql += f"-- SELECT ... FROM Wikiart WHERE <ALL {len(template['filters'])} FILTERS REQUIRED>;\n\n"
            
            complete_sql = header + example_sql + "-" * 60 + "\n\n"
            return complete_sql
        
        indices, conditions, relaxation_used = result
        
        if len(indices) < 2 or not conditions:
            return None
        
        # 基本列选择
        available_attrs = list(self.attr_desc_dict.keys())
        select_cols = random.sample(available_attrs, random.randint(1, 3))
        
        # 根据查询类型构建SQL
        sql_parts = []
        schema_attrs = select_cols.copy()
        
        if qtype == "SFW":
            sql_parts.append(f"SELECT {', '.join(select_cols)}")
            
        elif qtype == "SFWT":
            numeric_cols = [c for c in select_cols if c in self.numerical_attr]
            if not numeric_cols:
                numeric_cols = [random.choice(self.numerical_attr)]
                select_cols.extend(numeric_cols)
                schema_attrs.extend(numeric_cols)
            order_col = random.choice(numeric_cols)
            order_dir = random.choice(["ASC", "DESC"])
            limit_val = random.choice([5, 10, 15, 20])
            sql_parts.append(f"SELECT {', '.join(select_cols)}")
            sql_parts.append("FROM Wikiart")
            sql_parts.append(f"WHERE {' AND '.join(conditions)}")
            sql_parts.append(f"ORDER BY {order_col} {order_dir}")
            sql_parts.append(f"LIMIT {limit_val}")
            
        elif qtype == "SFWG":
            group_col = random.choice(self.category_attr)
            if group_col not in select_cols:
                select_cols.append(group_col)
                schema_attrs.append(group_col)
            sql_parts.append(f"SELECT {', '.join(select_cols)}")
            sql_parts.append("FROM Wikiart")
            sql_parts.append(f"WHERE {' AND '.join(conditions)}")
            sql_parts.append(f"GROUP BY {group_col}")
            
        elif qtype == "SFWA":
            func = random.choice(["COUNT", "MAX", "MIN", "AVG", "SUM"])
            if func == "COUNT" and random.random() < 0.3:
                agg_col = "*"
                select_cols = [f"{func}(*)"]
            else:
                agg_col = random.choice(self.numerical_attr)
                select_cols = [f"{func}({agg_col})"]
                schema_attrs = [agg_col]
            sql_parts.append(f"SELECT {', '.join(select_cols)}")
            
            
        elif qtype == "SFWGA":
            group_col = random.choice(self.category_attr)
            func = random.choice(["COUNT", "MAX", "MIN", "AVG", "SUM"])
            if func == "COUNT" and random.random() < 0.3:
                agg_col = "*"
                select_cols = [group_col, f"{func}(*)"]
            else:
                agg_col = random.choice([c for c in self.numerical_attr if c != group_col])
                select_cols = [group_col, f"{func}({agg_col})"]
                schema_attrs = [group_col, agg_col]
            sql_parts.append(f"SELECT {', '.join(select_cols)}")
            sql_parts.append("FROM Wikiart")
            sql_parts.append(f"WHERE {' AND '.join(conditions)}")
            sql_parts.append(f"GROUP BY {group_col}")
            
        elif qtype == "SFWGAT":
            group_col = random.choice(self.category_attr)
            func = random.choice(["MAX", "MIN", "AVG", "SUM"])
            agg_col = random.choice([c for c in self.numerical_attr if c != group_col])
            order_dir = random.choice(["ASC", "DESC"])
            limit_val = random.choice([5, 10, 15])
            select_cols = [group_col, f"{func}({agg_col})"]
            schema_attrs = [group_col, agg_col]
            sql_parts.append(f"SELECT {', '.join(select_cols)}")
            sql_parts.append("FROM Wikiart")
            sql_parts.append(f"WHERE {' AND '.join(conditions)}")
            sql_parts.append(f"GROUP BY {group_col}")
            sql_parts.append(f"ORDER BY {agg_col} {order_dir}")
            sql_parts.append(f"LIMIT {limit_val}")
        
        # 添加基本的FROM和WHERE（如果还没有）
        if len(sql_parts) == 1:  # 只有SELECT
            sql_parts.append("FROM Wikiart")
            if conditions:
                sql_parts.append(f"WHERE {' AND '.join(conditions)}")
        
        # 添加过滤条件中的属性到schema
        for condition in conditions:
            attr_name = condition.split()[0]
            if attr_name not in schema_attrs:
                schema_attrs.append(attr_name)
        
        # 生成完整SQL
        schema_sql = self.create_schema_sql(schema_attrs)
        query_sql = "\n".join(sql_parts) + ";"
        
        header = f"-- Query {query_id} - {qtype}\n"
        header += f"-- Template: {template['name']}\n"
        header += f"-- Description: {template['description']}\n"
        header += f"-- Result Rows: {len(indices)}\n"
        header += f"-- Filters Applied: {len(conditions)}/{len(template['filters'])} (EXACT MATCH REQUIRED)"
        if len(conditions) != len(template['filters']):
            header += " [ERROR: Filter count mismatch]"
        elif relaxation_used > 0:
            header += f" (Values relaxed {relaxation_used} times)"
        header += "\n\n"
        
        complete_sql = header + schema_sql + "\n\n" + query_sql + "\n\n" + "-" * 60 + "\n\n"
        
        return complete_sql

def generate_all_sql_files():
    """生成所有SQL文件，按类别和类型分文件夹"""
    print("生成SQL文件...")
    
    # 创建主输出目录
    base_dir = "./Fixed_filters/"
    os.makedirs(base_dir, exist_ok=True)
    
    generator = SQLGenerator()
    queries_per_type = 5  # 每个类型生成5个查询
    
    # 统计信息
    generation_stats = {}
    
    # 为每个模板创建文件夹
    for template in tqdm(TEMPLATES, desc="Processing templates"):
        template_dir = os.path.join(base_dir, template["name"])
        os.makedirs(template_dir, exist_ok=True)
        
        template_stats = {}
        
        # 为每个查询类型生成SQL文件
        for qtype in QUERY_TYPES:
            sql_content = []
            query_id = 1
            
            # 生成多个查询
            generated = 0
            attempts = 0
            max_attempts = queries_per_type * 5  # 增加尝试次数
            
            while generated < queries_per_type and attempts < max_attempts:
                attempts += 1
                
                sql = generator.generate_query_sql(template, qtype, query_id)
                if sql:
                    sql_content.append(sql)
                    generated += 1
                    query_id += 1
            
            template_stats[qtype] = generated
            
            # 保存SQL文件
            if sql_content:
                filename = os.path.join(template_dir, f"{qtype}.sql")
                with open(filename, 'w', encoding='utf-8') as f:
                    f.write(f"-- {template['description']} - {qtype} 查询集合\n")
                    f.write(f"-- 模板: {template['name']}\n")
                    f.write(f"-- Filter数量: {len(template['filters'])}\n")
                    f.write("-- " + "=" * 60 + "\n\n")
                    f.write("".join(sql_content))
            else:
                # 创建空文件说明原因
                filename = os.path.join(template_dir, f"{qtype}.sql")
                with open(filename, 'w', encoding='utf-8') as f:
                    f.write(f"-- {template['description']} - {qtype} 查询集合\n")
                    f.write(f"-- 模板: {template['name']}\n")
                    f.write(f"-- Filter数量: {len(template['filters'])}\n")
                    f.write("-- 注意: 由于Filter条件过于严格，未能生成有效查询\n")
                    f.write("-- 建议: 可以手动调整Filter条件或降低最小结果行数要求\n")
        
        generation_stats[template["name"]] = template_stats
    
    # 输出精简统计信息
    print(f"生成完成!")
    
    total_queries = 0
    failed_queries = 0
    
    for template_name, type_stats in generation_stats.items():
        for qtype, count in type_stats.items():
            total_queries += count
    
    # 统计失败的查询（通过检查SQL文件中的FAILED标记）
    for template in TEMPLATES:
        template_dir = os.path.join(base_dir, template["name"])
        for qtype in QUERY_TYPES:
            sql_file = os.path.join(template_dir, f"{qtype}.sql")
            if os.path.exists(sql_file):
                with open(sql_file, 'r', encoding='utf-8') as f:
                    content = f.read()
                    failed_queries += content.count("[GENERATION FAILED]")
    
    success_queries = total_queries - failed_queries
    
    print(f"成功: {success_queries} | 失败: {failed_queries} | 总计: {total_queries}")
    print(f"模板: {len(TEMPLATES)} | 查询类型: {len(QUERY_TYPES)}")
    
    if failed_queries > 0:
        print(f"注意: {failed_queries} 个查询因Filter组合过于严格而失败，但保持了完整的Filter数量")
    
    print(f"文件保存在: {base_dir}")

# 执行生成
generate_all_sql_files()

生成SQL文件...


Processing templates: 100%|██████████| 10/10 [00:04<00:00,  2.46it/s]

生成完成!
成功: 174 | 失败: 26 | 总计: 200
模板: 10 | 查询类型: 7
注意: 26 个查询因Filter组合过于严格而失败，但保持了完整的Filter数量
文件保存在: ./Fixed_filters/



