In [None]:
import re
import json
import pickle
import random
import openpyxl
import itertools
import pandas as pd
from datetime import datetime
from tqdm.autonotebook import tqdm
from collections import defaultdict

#### 定义输入输出路径，并加载数据
- ```dataset_dir```：标注好的数据集表格路径
- ```statistics_output_dir```：统计数据表输出路径，包括属性值、选择率、基数
- ```valid_where_output_dir```：所有有效谓词组合的输出路径
- **注意**：区分标注好的表格是 .csv 还是 .xlsx 格式，相应改变 pandas 读取形式


In [None]:
# Input: the annotated dataset table (CSV).
dataset_dir = r"/data2/liujinqi/Benchmark/Query/Wikiart_AutoConstruct/Text/Wikiart_All_Attr_Joint.csv"
# Output: per-attribute statistics table (value | count | selectivity).
statistics_output_dir = r"/data2/liujinqi/Benchmark/Query/Wikiart_AutoConstruct/Text/Wikiart_Statistics_Text.csv"
# Output: JSON file holding every valid WHERE predicate combination.
valid_where_output_dir = r"/data2/liujinqi/Benchmark/Query/Wikiart_AutoConstruct/Text/valid_WHERE.json"
# Alternative loader using pandas' default encoding:
# df = pd.read_csv(dataset_dir)
# The table is read as ISO-8859-1.
# NOTE(review): for an .xlsx table a different pandas reader would be needed.
df = pd.read_csv(dataset_dir, encoding='ISO-8859-1')

#### 定义属性，方便后面按不同属性类型设计不同的构造方法
- ```attr_desc_dict```：全部属性的集合，以及对应的自然语言描述
- ```non_numerical_attr```：非数值属性的集合
- ```numerical_attr```：数值属性的集合
- ```non_formatted_attr```：非格式化数值属性
- ```formatted_attr```：格式化数值属性，如日期
- ```category_attr```：固定类别的属性
- ```multi_value_attributes```：多值的属性，用"||"分隔或者逗号分隔

In [None]:
# Attribute -> natural-language description. The descriptions are currently
# empty strings; the keys define the full set of attributes used below.
attr_desc_dict = {
    "Name": "",
    "Nationality": "",
    "Birth_date": "",
    "Death_date": "",
    "Age": "",
    "Century": "",
    "Zodiac": "",
    "Birth_country": "",
    "Birth_city": "",
    "Birth_continent": "",
    "Death_country": "",
    "Death_city": "",
    "Field": "",
    "Genre": "",
    "Marriage": "",
    "Art_institution": "",
    "Teaching": "",
    "Awards": "",
    "Style": "",
    "Image_genre": "",
    "Color": "",
    "Tone": "",
    "Composition": "",
}

# Attributes that belong to the text modality.
text_attributes = [
    "Name", "Nationality", "Birth_date", "Death_date", "Age",
    "Century", "Zodiac", "Birth_country", "Birth_city", "Birth_continent",
    "Death_country", "Death_city", "Field", "Genre", "Marriage",
    "Art_institution", "Teaching", "Awards",
]

# Attributes that belong to the image modality.
image_attributes = ["Style", "Image_genre", "Color", "Tone", "Composition"]
# Earlier, wider image attribute set kept for reference:
# image_attributes = ["Name", "Artwork_URL", "Style", "Genre", "Theme", "Object", "Color",
#     "Tone", "Composition", "Regional_feature", "Art_movement", "Person_count"]

# Attributes whose cells may hold several values.
# NOTE(review): "Art_movement" is not a key of attr_desc_dict; it is kept
# because membership tests against this list are harmless for absent columns.
multi_value_attributes = ["Genre", "Field", "Nationality", "Art_movement"]

# Attributes compared as text (equality only). The two date attributes are
# listed here as well as in formatted_attr.
non_numerical_attr = [
    "Name", "Nationality", "Birth_date", "Death_date",
    "Century", "Zodiac", "Birth_country", "Birth_city", "Birth_continent",
    "Death_country", "Death_city", "Field", "Genre", "Marriage",
    "Art_institution", "Teaching", "Style", "Image_genre", "Color", "Tone",
    "Composition",
]

# Numeric attributes.
# numerical_attr = ["Person_count", "Age", "Awards"]
numerical_attr = ["Age", "Awards"]
# Numeric attributes stored as plain numbers (no special format).
# non_formatted_attr = ["Awards", "Person_count", "Age"]
non_formatted_attr = ["Awards", "Age"]
# Attributes stored in a parseable format (dates).
formatted_attr = ["Birth_date", "Death_date"]
# Attributes with a fixed set of categories. "Marriage" used to be listed
# twice; each category attribute now appears exactly once.
category_attr = ["Zodiac", "Birth_continent", "Marriage", "Century", "Teaching"]

#### 生成统计信息
- 属性 | 属性值 | 选择率 | 基数
- 用于后续构造 Filter

In [None]:
# Build one wide statistics table. Every attribute contributes three
# side-by-side columns (<attribute> | Count | Selectivity) and ends with a
# "(null)" row describing missing cells.
statistics = pd.DataFrame()
total_rows = len(df)

for column in attr_desc_dict.keys():
    if column in multi_value_attributes:
        expanded_values = []   # every individual value, one entry per occurrence
        rows_containing = {}   # value -> number of rows that contain it

        for value in df[column]:
            if pd.isna(value):
                # Missing cells are reported through the "(null)" row below.
                continue
            if isinstance(value, str) and '||' in value:
                split_values = [v.strip() for v in value.split('||') if v.strip()]
            else:
                split_values = [value]
            expanded_values.extend(split_values)
            # Count each row at most once per value so that a cell repeating
            # a value cannot push the selectivity above the true row share.
            for split_val in set(split_values):
                rows_containing[split_val] = rows_containing.get(split_val, 0) + 1

        value_counts = pd.Series(expanded_values).value_counts()
        selectivities = pd.Series({
            k: round(v / total_rows, 3) for k, v in rows_containing.items()
        })

    else:
        value_counts = df[column].value_counts()
        # Divide by the total row count (nulls included), matching the
        # multi-value branch and the "(null)" row. value_counts(normalize=True)
        # would divide by the number of non-null rows instead.
        selectivities = (value_counts / total_rows).round(3)

    null_count = df[column].isnull().sum()
    null_selectivity = round(null_count / total_rows, 3) if total_rows else 0

    result_df = pd.DataFrame({
        f"{column}": list(value_counts.index) + ["(null)"],
        'Count': list(value_counts.values) + [null_count],
        'Selectivity': [selectivities.get(val, 0) for val in value_counts.index] + [null_selectivity]
    })

    if statistics.empty:
        statistics = result_df
    else:
        # Column-wise concat: attributes with fewer distinct values are padded
        # with NaN rows.
        statistics = pd.concat([statistics, result_df], axis=1)

statistics.to_csv(statistics_output_dir, index=False)
statistics

#### 汇总可取属性值，以便后续据此构造 Filter
- ```attr_value_dict```：所有可取属性值的集合，每个属性有多个可取属性值，根据统计信息来定

In [None]:
# Candidate constants for each attribute; every (attribute, value) pair below
# is turned into one or more Filter predicates later in the notebook.
# NOTE(review): the values are hand-picked -- presumably from the statistics
# table produced above -- so confirm each spelling against the dataset.
attr_value_dict = {
    # NOTE(review): "Oswald Achenbac" may be a truncated spelling -- confirm.
    "Name": ["Oswald Achenbac", "Ben Shahn", "Silvestro Lega", "Samuel Finley Breese Morse", 
             "Ivan Grohar", "Oscar Agustín Alejandro Schulz Solari", "Thomas Girtin"],
    # NOTE(review): "Australia" is a country name among demonyms -- confirm.
    "Nationality": ["American", "French", "British", "German", "Italian", "Japanese", 
                    "Russian", "Australia", "Canadian", "Brazilian", "Mexican"],
    # Dates are written as YYYY/M/D.
    "Birth_date": ["1905/4/27", "1905/4/28", "1905/4/22", "1905/4/19", "1898/9/12"],
    "Death_date": ["2005/4/22", "2011/4/8", "1943/1/13", "1986/1/1", "1963/4/9"],
    "Age": [83, 17, 76, 15, 82, 78, 70, 85, 9, 65, 66, 87],
    # A century is either a single ordinal or an ordinal range.
    "Century": ["19th-20th", "20th", "20th-21st", "19th"],
    "Zodiac": ["Aries", "Taurus", "Pisces", "Gemini", "Leo", "Aquarius"],
    "Birth_country": ["United States", "France", "United Kingdom", "Germany", "Italy"],
    "Birth_city": ["Paris", "London", "New York", "Philadelphia", "Chicago"],
    "Birth_continent": ["Europe", "North America", "Asia", "South America"],
    "Death_country": ["United States", "France", "United Kingdom", "Italy"],
    # "New York" and "New York City" are listed as separate candidates.
    "Death_city": ["Paris", "New York", "London", "New York City"],
    "Field": ["Painting", "Sculpture", "Printmaking", "Illustration", "Photography"],
    "Genre": ["Abstract", "Landscape", "Portrait", "Figurative", "Surrealism"],
    "Marriage": ["Married", "Unmarried", "Remarried", "Divorced", "Widowed"],
    "Art_institution": ["Royal Academy of Arts", "Self-taught", "Art Students League of New York"],
    # Integer candidates; presumably a 0/1 flag -- TODO confirm.
    "Teaching": [0, 1],
    "Awards": [0, 1, 2, 3, 4, 5],
    "Style": ["Expressionism", "Romanticism", "Conceptual Art", "Abstract Expressionism", "Impressionism", "Realism",
              "Surrealism", "Pop Art", "Art Nouveau (Modern)", "Minimalism", "Post-Impressionism"],
    "Image_genre": ["Portrait", "Landscape", "Still Life", "Abstract", "Sculpture", "Figurative", "Genre Painting"],
    "Color": ["Earth Tones", "Blue", "Green", "Red", "Black And White", "Brown", "Yellow", "Multicolored", "White"],
    "Tone":["Neutral", "Bright", "Dark", "Warm", "Soft", "Light"],
    "Composition": ["Balanced", "Asymmetrical", "Centralized", "Dynamic", "Symmetrical",
                    "Central Focus", "Geometric", "Centered", "Minimalist", "Geometric Arrangement"],	
}

#### 定义查询构造数量
- ```max_filters```：WHERE 语句最大 Filter 数量
- ```min_rows```：结果表最少行数
- ```max_select```：SELECT 语句最大属性数量
- ```limit_list```：LIMIT 语句候选值
- ```sample_sfw```：SELECT | FROM | WHERE 查询数量
- ```sample_sfwt```：SELECT | FROM | WHERE | TOP-K 查询数量
- ```sample_sfwg```：SELECT | FROM | WHERE | GROUP BY 查询数量
- ```sample_sfwa```：SELECT | FROM | WHERE | AGGREGATION 查询数量
- ```sample_sfwga```：SELECT | FROM | WHERE | GROUP BY | AGGREGATION 查询数量
- ```sample_sfwgat```：SELECT | FROM | WHERE | GROUP BY | AGGREGATION | TOP-K 查询数量

In [None]:
# Query-construction limits.
max_filters = 5    # maximum number of Filters in one WHERE clause
min_rows = 5       # minimum number of rows a query result must contain
max_select = 3     # maximum number of attributes in one SELECT clause
limit_list = [1, 2, 5, 10, 20, 50]    # candidate values for LIMIT
# Number of queries to sample per query shape.
sample_sf = 10       # SELECT | FROM
sample_sfw = 20      # SELECT | FROM | WHERE
sample_sfwt = 20     # SELECT | FROM | WHERE | TOP-K
sample_sfwg = 20     # SELECT | FROM | WHERE | GROUP BY
sample_sfwa = 20     # SELECT | FROM | WHERE | AGGREGATION
sample_sfga = 10     # presumably SELECT | FROM | GROUP BY | AGGREGATION -- TODO confirm
sample_sfwga = 30    # SELECT | FROM | WHERE | GROUP BY | AGGREGATION
sample_sfwgat = 20   # SELECT | FROM | WHERE | GROUP BY | AGGREGATION | TOP-K

#### 定义 Filter 具体执行方法
- 不同类型的数据有不同的执行方法
- 区分非数值型、数值型、日期型
- 比较运算包括：大于（只限于数值型）、小于（只限于数值型）、等于（任意类型）
- **注意**：需要提前确认好日期型数据的具体格式，做好特殊情况相应处理

In [None]:
def non_numerical_equal_to(value, condition):
    """Return True when ``value`` and ``condition`` match as text.

    Both operands are converted with ``str``, lower-cased and stripped of
    surrounding whitespace before comparison, so the match ignores case and
    padding. A missing value becomes the text ``"nan"``.

    If the conversion fails, a message is printed and the operands are
    compared in whatever state they reached.
    """
    try:
        value = str(value).lower().strip()
        condition = str(condition).lower().strip()
    except Exception:
        # A bare ``except:`` would also swallow KeyboardInterrupt/SystemExit.
        print("Invalid data (non_numerical_equal_to). value:[%s] | condition: [%s]." % (value, condition))
    return value == condition

def number_greater_than(value, condition):
    """Return True when ``value`` is numerically greater than ``condition``.

    A missing ``value`` (``pd.isna``) is treated as 0.0. Both operands are
    converted to ``float`` and rounded to two decimals before comparison.

    If either operand cannot be converted, a message is printed and False
    is returned: unparseable data never satisfies the predicate.
    """
    if pd.isna(value):
        value = 0.00
    try:
        value = round(float(value), 2)
        condition = round(float(condition), 2)
    except (TypeError, ValueError, OverflowError):
        print("Invalid data (number_greater_than). value:[%s] | condition: [%s]." % (value, condition))
        # Comparing the raw operands here could raise TypeError (str vs float).
        return False
    return value > condition

def number_less_than(value, condition):
    """Return True when ``value`` is numerically less than ``condition``.

    A missing ``value`` (``pd.isna``) is treated as 0.0. Both operands are
    converted to ``float`` and rounded to two decimals before comparison.

    If either operand cannot be converted, a message is printed and False
    is returned: unparseable data never satisfies the predicate.
    """
    if pd.isna(value):
        value = 0.00
    try:
        value = round(float(value), 2)
        condition = round(float(condition), 2)
    except (TypeError, ValueError, OverflowError):
        print("Invalid data (number_less_than). value:[%s] | condition: [%s]." % (value, condition))
        # Comparing the raw operands here could raise TypeError (str vs float).
        return False
    return value < condition

def number_equal_to(value, condition):
    """Return True when ``value`` equals ``condition`` numerically.

    A missing ``value`` (``pd.isna``) is treated as 0.0. Both operands are
    converted to ``float`` and rounded to two decimals, so values that differ
    only beyond the second decimal compare equal.

    If either operand cannot be converted, a message is printed and False
    is returned.
    """
    if pd.isna(value):
        value = 0.00
    try:
        value = round(float(value), 2)
        condition = round(float(condition), 2)
    except (TypeError, ValueError, OverflowError):
        print("Invalid data (number_equal_to). value:[%s] | condition: [%s]." % (value, condition))
        # Do not fall back to comparing half-converted operands.
        return False
    return value == condition

def parse_date(date_str):
    """Parse a date cell into a ``datetime``, or return None.

    Accepted text layouts, tried in order: ``YYYY/M/D``, ``YYYY/M``,
    ``YYYY``, ``YYYY-M-D`` and ``YYYY-M``. Missing values (``pd.isna``) and
    text matching none of the layouts yield None.

    Non-string input is converted with ``str`` first, so a bare integer year
    such as ``1905`` parses instead of raising TypeError inside ``strptime``.
    ``datetime`` instances are returned unchanged.
    """
    if pd.isna(date_str):
        return None
    if isinstance(date_str, datetime):
        return date_str
    text = str(date_str).strip()
    for fmt in ('%Y/%m/%d', '%Y/%m', '%Y', '%Y-%m-%d', '%Y-%m'):
        try:
            return datetime.strptime(text, fmt)
        except ValueError:
            continue
    return None

def date_greater_than(value, condition):
    """Return True when ``value`` is a later date than ``condition``.

    Both operands go through ``parse_date``; if either cannot be parsed the
    result is False.
    """
    lhs = parse_date(value)
    rhs = parse_date(condition)
    return lhs is not None and rhs is not None and lhs > rhs

def date_less_than(value, condition):
    """Return True when ``value`` is an earlier date than ``condition``.

    Both operands go through ``parse_date``; if either cannot be parsed the
    result is False.
    """
    lhs = parse_date(value)
    rhs = parse_date(condition)
    return lhs is not None and rhs is not None and lhs < rhs

def date_equal_to(value, condition):
    """Return True when ``value`` and ``condition`` parse to the same date.

    Both operands go through ``parse_date``; if either cannot be parsed the
    result is False.
    """
    lhs = parse_date(value)
    rhs = parse_date(condition)
    return lhs is not None and rhs is not None and lhs == rhs

def parse_century(value):
    """Convert a century label into a ``(start_year, end_year)`` tuple.

    ``"19th"`` maps to ``(1800, 1900)`` and a range such as ``"19th-20th"``
    maps to ``(1800, 2000)``. Missing values and text that does not start
    with an ordinal (``<digits>`` followed by ``th``/``st``/``nd``/``rd``)
    yield None.

    Non-string input is converted with ``str`` first, so a numeric cell
    returns None instead of raising TypeError inside ``re.match``.
    """
    if pd.isna(value):
        return None

    match = re.match(
        r"(\d+)(?:th|st|nd|rd)(?:-(\d+)(?:th|st|nd|rd))?",
        str(value).strip(),
    )
    if not match:
        return None

    start_century = int(match.group(1))
    # A single ordinal describes a one-century span.
    end_century = int(match.group(2)) if match.group(2) else start_century
    return (start_century - 1) * 100, end_century * 100


def century_greater_than(value, condition):
    """Return True when the century span of ``value`` lies after ``condition``.

    Both labels go through ``parse_century``; if either cannot be parsed the
    result is False. Spans share their boundary year (the 19th century ends
    at 1900, where the 20th begins), so the comparison is inclusive:
    ``"20th"`` is greater than ``"19th"``, while overlapping spans such as
    ``"19th-20th"`` versus ``"19th"`` are not.
    """
    value_parsed = parse_century(value)
    condition_parsed = parse_century(condition)

    if value_parsed is None or condition_parsed is None:
        return False

    # value starts no earlier than the moment condition ends
    return value_parsed[0] >= condition_parsed[1]

def century_less_than(value, condition):
    """Return True when the century span of ``value`` lies before ``condition``.

    Both labels go through ``parse_century``; if either cannot be parsed the
    result is False. This mirrors ``century_greater_than``: the whole span of
    ``value`` must end no later than the start of ``condition``. ``"19th"``
    is less than ``"20th"``, but a label is never less than itself or than a
    span it overlaps.
    """
    value_parsed = parse_century(value)
    condition_parsed = parse_century(condition)

    if value_parsed is None or condition_parsed is None:
        return False

    # value ends no later than the moment condition starts
    return value_parsed[1] <= condition_parsed[0]

def number_greater_equal(value, condition):
    """Return True when ``value`` is numerically >= ``condition``.

    A missing ``value`` (``pd.isna``) is treated as 0.0. Both operands are
    converted to ``float`` and rounded to two decimals before comparison.

    If either operand cannot be converted, a message is printed and False
    is returned: unparseable data never satisfies the predicate.
    """
    if pd.isna(value):
        value = 0.00
    try:
        value = round(float(value), 2)
        condition = round(float(condition), 2)
    except (TypeError, ValueError, OverflowError):
        print("Invalid data (number_greater_equal). value:[%s] | condition: [%s]." % (value, condition))
        # Comparing the raw operands here could raise TypeError (str vs float).
        return False
    return value >= condition

def number_less_equal(value, condition):
    """Return True when ``value`` is numerically <= ``condition``.

    A missing ``value`` (``pd.isna``) is treated as 0.0. Both operands are
    converted to ``float`` and rounded to two decimals before comparison.

    If either operand cannot be converted, a message is printed and False
    is returned: unparseable data never satisfies the predicate.
    """
    if pd.isna(value):
        value = 0.00
    try:
        value = round(float(value), 2)
        condition = round(float(condition), 2)
    except (TypeError, ValueError, OverflowError):
        print("Invalid data (number_less_equal). value:[%s] | condition: [%s]." % (value, condition))
        # Comparing the raw operands here could raise TypeError (str vs float).
        return False
    return value <= condition

def date_greater_equal(value, condition):
    """Return True when ``value`` is the same date as ``condition`` or later.

    Both operands go through ``parse_date``; if either cannot be parsed the
    result is False.
    """
    lhs = parse_date(value)
    rhs = parse_date(condition)
    return lhs is not None and rhs is not None and lhs >= rhs

def date_less_equal(value, condition):
    """Return True when ``value`` is the same date as ``condition`` or earlier.

    Both operands go through ``parse_date``; if either cannot be parsed the
    result is False.
    """
    lhs = parse_date(value)
    rhs = parse_date(condition)
    return lhs is not None and rhs is not None and lhs <= rhs

def century_greater_equal(value, condition):
    """Return True when ``value`` starts no earlier than ``condition`` starts.

    Both labels go through ``parse_century``; if either cannot be parsed the
    result is False. Only the start years of the two spans are compared.
    """
    span = parse_century(value)
    bound = parse_century(condition)
    if span is not None and bound is not None:
        return span[0] >= bound[0]
    return False

def century_less_equal(value, condition):
    """Return True when ``value`` ends no later than ``condition`` ends.

    Both labels go through ``parse_century``; if either cannot be parsed the
    result is False. Only the end years of the two spans are compared.
    """
    span = parse_century(value)
    bound = parse_century(condition)
    if span is not None and bound is not None:
        return span[1] <= bound[1]
    return False


#### 枚举所有可能的 Filter
- Filter 的可能取值来自于字典：```attr_value_dict```
- 用一个列表记录满足 Filter 的行索引
- 输入：```dataset_dir```，即原始标注数据表
- 输出：```filter_dict.json```，包含各属性，每个属性中包含所有可能的过滤条件及对应满足条件的行索引号
- 这一步构建好所有可能的 Filter，后面从此 Filter 集合中选取 Filter 进行排列组合，构造 WHERE 语句

In [None]:
#### 增强的Filter构造函数 - 完全支持多值属性

def enhanced_non_numerical_equal_to_with_split(cell, condition, multi_value_attrs, separator='||'):
    """Case-insensitive equality test that understands multi-value cells.

    Args:
        cell: The table cell. Text containing ``separator`` is treated as a
            list of values and matches when any element equals ``condition``.
        condition: The constant to look for.
        multi_value_attrs: Accepted for interface compatibility; not
            consulted -- the cell content alone decides whether it is split.
        separator: Delimiter between the values of a multi-value cell.

    Returns:
        True on a match. Two missing operands match each other; a missing
        operand never matches a present one. Any unexpected error is printed
        and reported as False.

    A cell that is not text (for example ``1.0`` from an integer column that
    pandas promoted to float) is additionally compared numerically, so it
    matches the condition ``1`` even though ``"1.0" != "1"`` as text.
    """
    try:
        # Missing values only match other missing values.
        if pd.isna(cell) or pd.isna(condition):
            return pd.isna(cell) and pd.isna(condition)

        cell_str = str(cell).strip()
        condition_lower = str(condition).strip().lower()

        # Multi-value cell: match against any of its elements.
        if separator in cell_str:
            cell_values = [v.strip().lower() for v in cell_str.split(separator) if v.strip()]
            return condition_lower in cell_values

        # Single value: plain text comparison first.
        if cell_str.lower() == condition_lower:
            return True

        # Numeric fallback, restricted to cells pandas did not parse as text
        # so that string columns keep their exact-text semantics.
        if not isinstance(cell, str):
            try:
                return float(cell) == float(condition_lower)
            except (TypeError, ValueError):
                return False
        return False

    except Exception as e:
        print(f"比较出错 - cell: [{cell}], condition: [{condition}], error: {e}")
        return False

def get_multi_value_statistics(df, column, multi_value_attrs, separator='||'):
    """
    Collect detailed statistics for a multi-value attribute.

    Columns not listed in ``multi_value_attrs`` yield
    ``{"is_multi_value": False}``. Otherwise every non-null cell is expanded
    on ``separator`` (elements are stripped, empty elements dropped) and the
    returned dict reports how many rows held several values, the number of
    non-null rows, the number of expanded values, the number of distinct
    values, the per-value occurrence counts and the average number of values
    per non-null row.

    Returns:
    dict: the statistics described above
    """
    if column not in multi_value_attrs:
        return {"is_multi_value": False}

    non_null_cells = df[column].dropna()
    rows_with_several = 0
    expanded_total = 0
    occurrences = {}

    for raw in non_null_cells:
        if isinstance(raw, str) and separator in raw:
            rows_with_several += 1
            tokens = [piece.strip() for piece in raw.split(separator) if piece.strip()]
        else:
            # A cell without the separator counts as exactly one value.
            tokens = [str(raw).strip()]

        expanded_total += len(tokens)
        for token in tokens:
            occurrences[token] = occurrences.get(token, 0) + 1

    row_total = len(non_null_cells)
    if row_total > 0:
        average = expanded_total / row_total
    else:
        average = 0

    return {
        "is_multi_value": True,
        "multi_value_rows": rows_with_several,
        "total_rows": row_total,
        "total_expanded_values": expanded_total,
        "unique_values": len(occurrences),
        "value_distribution": occurrences,
        "avg_values_per_row": average
    }
    
# Re-read the dataset so this cell works on an unmodified copy of the table.
df = pd.read_csv(dataset_dir, encoding='ISO-8859-1')

# attribute -> {predicate text -> list of row indices satisfying it}
enhanced_filter_dict = {}

# First pass: register every attribute and print multi-value statistics.
for key in attr_value_dict.keys():
    enhanced_filter_dict[key] = {}
    print(f"处理属性: {key}")
    
    # Gather statistics for multi-value attributes (informational only).
    multi_stats = get_multi_value_statistics(df, key, multi_value_attributes)
    if multi_stats["is_multi_value"]:
        print(f"  多值属性 - 唯一值: {multi_stats['unique_values']}, 平均每行值数: {multi_stats['avg_values_per_row']:.2f}")

# Second pass: evaluate every candidate predicate against the table.
# Only existing keys are reassigned, so the dict size is stable while iterating.
for key, value in tqdm(enhanced_filter_dict.items(), desc="构建Filter"):
    condition_dict = {}

    ###### Non-numerical attributes: equality only ######
    if key in non_numerical_attr and key not in formatted_attr:
        for possible_value in attr_value_dict[key]:
            # Multi-value-aware comparison; the closure reads the current
            # possible_value and is applied immediately below.
            def enhanced_equal_compare(cell):
                return enhanced_non_numerical_equal_to_with_split(
                    cell, possible_value, multi_value_attributes
                )

            result = df[key].apply(enhanced_equal_compare)
            result_indices = df[result].index.tolist()
            
            # Predicate text; the value is always wrapped in single quotes.
            condition_key = f"=='{possible_value}'"
            
            condition_dict[condition_key] = result_indices
            
            # Report the selectivity of predicates that match at least one row.
            if len(result_indices) > 0:
                selectivity = len(result_indices) / len(df)
                print(f"    {possible_value}: {len(result_indices)} 行 (选择率: {selectivity:.3f})")

    ###### Numerical and date attributes: five comparison operators ######
    elif key in numerical_attr or key in formatted_attr:
        for possible_value in attr_value_dict[key]:
            
            # Defaults stay empty when neither sub-branch below applies.
            result_indices_less = []
            result_indices_greater = []
            result_indices_equal = []
            result_indices_less_equal = []
            result_indices_greater_equal = []
            
            # Date attributes.
            if key in formatted_attr:
                if key == "Birth_date" or key == "Death_date":
                    result_index_less = df[key].apply(date_less_than, condition=possible_value)
                    result_index_greater = df[key].apply(date_greater_than, condition=possible_value)
                    result_index_equal = df[key].apply(date_equal_to, condition=possible_value)
                    result_index_less_equal = df[key].apply(date_less_equal, condition=possible_value)
                    result_index_greater_equal = df[key].apply(date_greater_equal, condition=possible_value)
                    
                    result_indices_less = df[result_index_less].index.tolist()
                    result_indices_greater = df[result_index_greater].index.tolist()
                    result_indices_equal = df[result_index_equal].index.tolist()
                    result_indices_less_equal = df[result_index_less_equal].index.tolist()
                    result_indices_greater_equal = df[result_index_greater_equal].index.tolist()
            
            # Plain numeric attributes.
            elif key in non_formatted_attr:
                result_index_less = df[key].apply(number_less_than, condition=possible_value)
                result_index_greater = df[key].apply(number_greater_than, condition=possible_value)
                result_index_equal = df[key].apply(number_equal_to, condition=possible_value)
                result_index_less_equal = df[key].apply(number_less_equal, condition=possible_value)
                result_index_greater_equal = df[key].apply(number_greater_equal, condition=possible_value)
                
                result_indices_less = df[result_index_less].index.tolist()
                result_indices_greater = df[result_index_greater].index.tolist()
                result_indices_equal = df[result_index_equal].index.tolist()
                result_indices_less_equal = df[result_index_less_equal].index.tolist()
                result_indices_greater_equal = df[result_index_greater_equal].index.tolist()

            # Register all five comparison operators (values are not quoted).
            condition_dict[f"<{possible_value}"] = result_indices_less
            condition_dict[f">{possible_value}"] = result_indices_greater
            condition_dict[f"=={possible_value}"] = result_indices_equal
            condition_dict[f"<={possible_value}"] = result_indices_less_equal
            condition_dict[f">={possible_value}"] = result_indices_greater_equal

    ###### Century: range comparisons on top of equality ######
    # This is a separate `if`, so it runs in addition to whichever branch
    # above handled the key; the "=='...'" entry written here replaces any
    # entry with the same text produced by the equality branch.
    if key == "Century":
        for possible_value in attr_value_dict[key]:
            result_index_less = df[key].apply(century_less_than, condition=possible_value)
            result_index_greater = df[key].apply(century_greater_than, condition=possible_value)
            result_index_less_equal = df[key].apply(century_less_equal, condition=possible_value)
            result_index_greater_equal = df[key].apply(century_greater_equal, condition=possible_value)
            
            # Exact, case-sensitive label equality; missing values never match.
            def century_equal_to(value, condition):
                if pd.isna(value) or pd.isna(condition):
                    return False
                return value == condition
            
            result_index_equal = df[key].apply(century_equal_to, condition=possible_value)
            
            result_indices_less = df[result_index_less].index.tolist()
            result_indices_greater = df[result_index_greater].index.tolist()
            result_indices_equal = df[result_index_equal].index.tolist()
            result_indices_less_equal = df[result_index_less_equal].index.tolist()
            result_indices_greater_equal = df[result_index_greater_equal].index.tolist()
            
            condition_dict[f"<'{possible_value}'"] = result_indices_less
            condition_dict[f">'{possible_value}'"] = result_indices_greater
            condition_dict[f"=='{possible_value}'"] = result_indices_equal
            condition_dict[f"<='{possible_value}'"] = result_indices_less_equal
            condition_dict[f">='{possible_value}'"] = result_indices_greater_equal
    
    enhanced_filter_dict[key] = condition_dict

# Persist the filter dictionary for the following cells.
with open("./filter_dict.json", 'w') as f:
    json.dump(enhanced_filter_dict, f, ensure_ascii=False, indent=2)

print("增强的Filter字典构建完成！")

print(f"\n文件已保存: ./filter_dict.json")

目前问题："Birth_date": {}, "Death_date": {}, "Influenced_by": {}, "Influenced_on": {}, "Teaching": {}, "Image_genre": {}是空值

其中，Influenced_by、Influenced_on、Artwork_URL 这三个属性的值缺失，需要处理一下数值的问题。

需要处理的是：有||值，图片和文本属性要区分，

#### 枚举所有 Filter 的可能排列组合，构造 WHERE 语句
- ```balanced_sample()```：控制均匀采样 Filter
     - 保证每个属性出现的次数是平衡的

In [None]:
###### 均匀采样函数 ######
def balanced_sample(filters, sample_size=10, random_seed=None):
    """Sample filters so that every attribute is represented about equally.

    Args:
        filters: Sequence of filter tuples whose first element is the
            attribute (column) name. The remaining elements are opaque and
            may be unhashable (e.g. a ``set`` of row indices).
        sample_size: Total number of filters wanted.
        random_seed: When given, seeds the global ``random`` module first.

    Returns:
        A list of at most ``sample_size`` filters. Each attribute receives
        ``sample_size // n_attributes`` slots, and the remainder is spread
        over randomly chosen attributes. Slots that an attribute cannot fill
        are handed, one at a time in round-robin order, to attributes that
        still have unused filters.
    """
    if random_seed is not None:
        random.seed(random_seed)

    col_to_filters = defaultdict(list)
    for filter_item in filters:
        col_to_filters[filter_item[0]].append(filter_item)

    unique_cols = list(col_to_filters)
    num_cols = len(unique_cols)
    if num_cols == 0:
        return []

    base_num, remainder = divmod(sample_size, num_cols)
    quota = {col: base_num for col in unique_cols}
    for col in random.sample(unique_cols, remainder):
        quota[col] += 1

    sampled_filters = []
    leftovers = {}   # col -> filters of that column not sampled yet
    overflow = 0     # slots that short columns could not fill

    for col in unique_cols:
        pool = col_to_filters[col]
        required = quota[col]
        if len(pool) >= required:
            # Sample positions rather than items: this tracks the unused
            # filters without hashing them (they may contain sets).
            chosen = random.sample(range(len(pool)), required)
            chosen_positions = set(chosen)
            sampled_filters.extend(pool[i] for i in chosen)
            leftovers[col] = [
                item for i, item in enumerate(pool) if i not in chosen_positions
            ]
        else:
            sampled_filters.extend(pool)
            leftovers[col] = []
            overflow += required - len(pool)

    # Redistribute unfilled slots across columns that still have spare filters.
    remaining_cols = [col for col in unique_cols if leftovers[col]]
    while overflow > 0 and remaining_cols:
        for col in list(remaining_cols):
            spare = leftovers[col]
            if not spare:
                remaining_cols.remove(col)
                continue
            sampled_filters.append(spare.pop(random.randrange(len(spare))))
            overflow -= 1
            if overflow == 0:
                break

    return sampled_filters[:sample_size]

#### 枚举所有 Filter 的可能排列组合，构造 WHERE 语句
- 枚举所有可能的 Filter 排列组合

In [None]:
# Reload the filter dictionary written by the previous cell.
with open("./filter_dict.json", 'r') as source:
    filter_dict = json.load(source)

# Flatten to (column, predicate text, set of matching row indices) triples.
filters = [
    (col, cond, set(indices))
    for col, conditions in filter_dict.items()
    for cond, indices in conditions.items()
]

# Keep 20 filters, balanced across attributes, with a fixed seed.
filters = balanced_sample(filters, sample_size=20, random_seed=42)

# Enumerate every ordered selection of 1..max_filters distinct filters.
all_combinations = []
for size in tqdm(range(1, max_filters + 1)):
    for ordered_selection in itertools.permutations(filters, size):
        all_combinations.append(ordered_selection)

with open("./all_combinations.pkl", "wb") as sink:
    pickle.dump(all_combinations, sink)

#### 枚举所有 Filter 的可能排列组合，构造 WHERE 语句
- ```select_combinations_with_ratio()```：按不同 Filter 数量的比例构造 WHERE 语句
- 1 - 5 个 Filter 的比例为：2 : 3 : 3 : 1 : 1

In [None]:
def select_combinations_with_ratio(all_combinations, total=1000, ratio=[2, 3, 3, 1, 1], min_len=1, max_len=5):
    """Sample ``total`` combinations, split across lengths by ``ratio``.

    Args:
        all_combinations: Iterable of combinations (anything with ``len``).
        total: Number of combinations to return.
        ratio: Relative weight of each length from ``min_len`` to
            ``max_len``, in that order. The argument is never modified.
        min_len: Smallest combination length considered.
        max_len: Largest combination length considered.

    Returns:
        A shuffled list of exactly ``total`` combinations. A length that has
        fewer combinations than its share contributes all of them; the
        shortfall is handed, in ascending length order, to lengths that still
        have spare combinations.

    Raises:
        ValueError: When fewer than ``total`` combinations can be supplied.
    """
    # Work on a private copy: the allocation below zeroes entries, which must
    # not leak into the caller's list or into the shared default argument.
    ratio = list(ratio)

    length_to_combinations = defaultdict(list)
    for combo in all_combinations:
        l = len(combo)
        if min_len <= l <= max_len:
            length_to_combinations[l].append(combo)

    total_ratio = sum(ratio)

    samples_per_length = {}
    remaining_total = total

    # First pass: give every length its proportional share, capped by supply.
    for i, l in enumerate(range(min_len, max_len + 1)):
        expected_samples = int(total * ratio[i] / total_ratio)
        available_samples = len(length_to_combinations[l])
        if available_samples < expected_samples:
            samples_per_length[l] = available_samples
            remaining_total -= available_samples
            ratio[i] = 0  # exhausted: excluded from the top-up pass
        else:
            samples_per_length[l] = expected_samples
            remaining_total -= expected_samples

    # Second pass: top up from lengths that still have spare combinations.
    if remaining_total > 0:
        for i, l in enumerate(range(min_len, max_len + 1)):
            if ratio[i] > 0:
                possible_to_add = len(length_to_combinations[l]) - samples_per_length[l]
                if possible_to_add > 0:
                    additional_samples = min(possible_to_add, remaining_total)
                    samples_per_length[l] += additional_samples
                    remaining_total -= additional_samples
                if remaining_total == 0:
                    break

    if remaining_total != 0:
        raise ValueError(f"The number of combinations cannot meet the requirements, with {remaining_total} combinations remaining.")

    selected = []
    for l in range(min_len, max_len + 1):
        selected += random.sample(length_to_combinations[l], samples_per_length[l])

    random.shuffle(selected)
    return selected

#### 枚举所有 Filter 的可能排列组合，构造 WHERE 语句
- 析取 + 合取 + 析取与合取混合
- 去掉等效表达式
- 每个查询的结果表至少包含 ```min_rows``` 行 

In [None]:
# Reload the enumerated filter combinations from the pickle file.
# NOTE(review): pickle.load executes arbitrary code from the file; this is
# only safe if the pickle was produced locally by this notebook -- confirm.
with open("./all_combinations.pkl", "rb") as f:
     all_combinations = pickle.load(f)

def normalize_expression(expression):
    """Return a canonical form of a WHERE expression for duplicate detection.

    The expression is read as an OR of terms, where each term is a single
    predicate or a (possibly nested) parenthesised AND of predicates. The
    predicates inside each term are sorted, then the terms are sorted, so
    logically equivalent orderings such as ``((A AND B) AND C)`` and
    ``((C AND A) AND B)`` produce the same string.

    Assumes predicates neither start with ``(`` nor end with ``)`` and never
    contain the separators `` AND `` / `` OR `` themselves -- TODO confirm for
    any new attribute values.
    """
    def canonical_term(term):
        # Grouping parentheses only sit at the outer edges of each predicate
        # once the term is split on " AND ".
        atoms = sorted(
            atom.strip().lstrip("(").rstrip(")") for atom in term.split(" AND ")
        )
        if len(atoms) == 1:
            return atoms[0]
        return f"({' AND '.join(atoms)})"

    terms = sorted(canonical_term(term) for term in expression.split(" OR "))
    return " OR ".join(terms)

# Build WHERE clauses: for every sampled filter combination try every AND/OR
# operator assignment, keep those returning at least `min_rows` rows, and
# drop expressions already seen in an equivalent form.
valid_where = []
seen_expressions = set()

sampled_combinations = select_combinations_with_ratio(all_combinations, total=1000, ratio=[2, 3, 3, 1, 1], min_len=1, max_len=max_filters)

for combo in tqdm(sampled_combinations):
    condition_sets = [set(item[2]) for item in combo]
    predicates = [f"{col}{cond}" for col, cond, _ in combo]

    for op in itertools.product(["and", "or"], repeat=len(condition_sets) - 1):
        # AND binds to the running term and OR starts a new term, so the
        # result is an OR over AND-groups.
        current_sets = [condition_sets[0]]
        current_predicates = [predicates[0]]

        for i, logic in enumerate(op):
            if logic == "and":
                # Build a NEW set: an in-place `&=` would shrink the shared
                # entry of condition_sets and corrupt every later operator
                # assignment evaluated for this combination.
                current_sets[-1] = current_sets[-1] & condition_sets[i + 1]
                current_predicates[-1] = f"({current_predicates[-1]} AND {predicates[i + 1]})"
            elif logic == "or":
                current_sets.append(condition_sets[i + 1])
                current_predicates.append(predicates[i + 1])

        final_result = set().union(*current_sets)
        if len(final_result) < min_rows:
            continue

        expression = " OR ".join(current_predicates)
        normalized_expression = normalize_expression(expression)
        if normalized_expression in seen_expressions:
            continue
        seen_expressions.add(normalized_expression)

        valid_where.append({
            # Sorted so the stored row indices are deterministic.
            "WHERE Indices": sorted(final_result),
            "WHERE Total Rows": len(final_result),
            "Combination": [[c[0], c[1]] for c in combo],
            "Operators": list(op),
            "WHERE": expression
        })

# Write once after the loop; rewriting the whole file for every combination
# produces the same final content at quadratic I/O cost.
if valid_where:
    with open(valid_where_output_dir, 'w') as f:
        json.dump(valid_where, f, ensure_ascii=False)

#### 定义表头 SCHEMA
- ```create_schema()```：SCHEMA 定义函数，数据共三种类型：VARCHAR(255) | DATE | FLOAT

In [None]:
def create_schema(attr_desc_dict, query_attr_list):
    """Build the SCHEMA entry of a query.

    The schema holds every attribute the query uses plus up to four randomly
    chosen distractor attributes the query does not use.

    Args:
        attr_desc_dict: Mapping attribute -> natural-language description.
        query_attr_list: Attributes referenced by the query.

    Returns:
        dict mapping attribute -> ``[sql_type, description]`` where
        ``sql_type`` is ``"DATE"``, ``"FLOAT"`` or ``"VARCHAR(255)"``.
        Attributes found in none of the module-level type lists are omitted.
    """
    schema = {}
    # Sorted so the candidate order (and with it the seeded sample) does not
    # depend on per-process string hashing.
    candidates = sorted(set(attr_desc_dict.keys()) - set(query_attr_list))
    # Never request more distractors than there are candidates.
    extra_count = random.randint(0, min(4, len(candidates)))
    redundant_attr_list = random.sample(candidates, extra_count)
    schema_attr_list = list(query_attr_list) + redundant_attr_list
    for key in schema_attr_list:
        # Dates are tested first: the date attributes are also listed in
        # non_numerical_attr, which would otherwise make DATE unreachable.
        if key in formatted_attr:
            schema[key] = ["DATE", attr_desc_dict[key]]
        elif key in non_formatted_attr:
            schema[key] = ["FLOAT", attr_desc_dict[key]]
        elif key in non_numerical_attr:
            schema[key] = ["VARCHAR(255)", attr_desc_dict[key]]
    return schema

#### json 自定义格式化输出

In [None]:
def custom_json_dump(data, filename):
    """Serialise ``data`` as JSON with a custom line layout and write it.

    The compact ``json.dumps`` output is reshaped by plain text substitutions
    so that each object starts on its own line and each key sits on a
    tab-indented line. The file is written as UTF-8.
    """
    text = json.dumps(data, ensure_ascii=False)

    # Applied in order; later rules rely on the newlines added by earlier ones.
    substitutions = (
        (", {", ",\n{"),
        ("{\"", "{\n\t\""),
        ("]},\n", "]\n},\n"),
        ("e},\n", "e\n},\n"),
        ("}},\n", "}\n},\n"),
    )
    for old, new in substitutions:
        text = text.replace(old, new)

    # Move every key that follows a comma onto its own indented line.
    text = re.sub(r',\s*"([^"]+)":', r', \n\t"\1":', text)

    with open(filename, 'w', encoding='utf-8') as sink:
        sink.write(text)

#### 构建 SELECT | FROM

In [None]:
def create_sf_queries():
    """Generate plain SELECT ... FROM queries that cover the whole table.

    Reads the notebook-level ``sample_sf`` (number of queries),
    ``max_select``, ``attr_value_dict``, ``df`` and ``attr_desc_dict``.

    Returns:
        list of query dicts of type "SF". Each one lists every row position
        of ``df`` under "WHERE Indices", has no filter, and carries a
        "SCHEMA" built by ``create_schema``.
    """
    total_rows = len(df)
    queries = []

    for _ in tqdm(range(sample_sf)):
        # Project between 1 and max_select distinct columns.
        columns = random.sample(list(attr_value_dict.keys()), random.randint(1, max_select))

        entry = {
            "Type": "SF",
            "SELECT": columns,
            "WHERE Indices": list(range(total_rows)),  # every row
            "WHERE Total Rows": total_rows,
            "Combination": [],                         # no filter combination
            "Operators": [],                           # no operators
            "WHERE": "None",                           # no WHERE condition
        }
        entry["SCHEMA"] = create_schema(attr_desc_dict, columns.copy())
        queries.append(entry)

    return queries


valid_sf = create_sf_queries()
custom_json_dump(valid_sf, "./SELECT_FROM.json")


#### 构建 SELECT | FROM | WHERE

In [None]:
# Start from every valid WHERE predicate combination written earlier.
with open(valid_where_output_dir, 'r') as where_file:
    valid_sfw = json.load(where_file)

for where_entry in tqdm(valid_sfw):
    selected_columns = random.sample(list(attr_value_dict.keys()), random.randint(1, max_select))
    where_entry["Type"] = "SFW"
    where_entry["SELECT"] = selected_columns

    # The schema covers the projected columns plus every filtered attribute.
    query_attr_list = list(selected_columns)
    for condition in where_entry["Combination"]:
        filtered_attr = condition[0]
        if filtered_attr not in query_attr_list:
            query_attr_list.append(filtered_attr)
    where_entry["SCHEMA"] = create_schema(attr_desc_dict, query_attr_list)

sampled_valid_sfw = random.sample(valid_sfw, sample_sfw)
custom_json_dump(sampled_valid_sfw, "./SELECT_FROM_WHERE.json")

#### 构建 SELECT | FROM | WHERE | GROUP BY
- 返回多个表，每个组一个表

In [None]:
# Build SELECT | FROM | WHERE | GROUP BY queries on top of the saved WHERE predicates.
with open(valid_where_output_dir, 'r') as f:
    valid_where = json.load(f)

# Only single-valued categorical attributes are used as grouping keys.
valid_groupby_attrs = [attr for attr in category_attr if attr not in multi_value_attributes]

valid_sfwg = []
for query in tqdm(valid_where):
    selected_columns = random.sample(list(attr_value_dict.keys()), random.randint(1, max_select))
    groupby_columns = random.sample(valid_groupby_attrs, 1)

    # Ignore indices that fall outside the currently loaded table.
    row_indices = [idx for idx in query.get("WHERE Indices", []) if idx < len(df)]
    if not row_indices:
        continue

    remaining_rows = df.loc[row_indices].groupby(groupby_columns).size().reset_index()
    # Keep only predicates that leave enough distinct groups.
    if len(remaining_rows) <= (min_rows // 2):
        continue

    # Append the grouping key to SELECT only when it is not projected already.
    # Membership is tested per column name: testing the whole list against the
    # names never matches, which listed an already-selected key twice.
    select_list = selected_columns + [col for col in groupby_columns if col not in selected_columns]

    query["Type"] = "SFWG"
    query["SELECT"] = select_list
    query["GROUP BY Total Rows (Groups)"] = len(remaining_rows)
    query["GROUP BY"] = groupby_columns

    # SCHEMA has to declare everything the query touches: the projected
    # columns (including the grouping key) and every filtered attribute.
    query_attr_list = list(select_list)
    for condition in query["Combination"]:
        if condition[0] not in query_attr_list:
            query_attr_list.append(condition[0])
    query["SCHEMA"] = create_schema(attr_desc_dict, query_attr_list)
    valid_sfwg.append(query)

# random.sample raises ValueError when asked for more items than exist.
sampled_valid_sfwg = random.sample(valid_sfwg, min(sample_sfwg, len(valid_sfwg)))
custom_json_dump(sampled_valid_sfwg, "./SELECT_FROM_WHERE_GROUPBY.json")

#### 构建 SELECT | FROM | WHERE | AGGREGATION
- 支持的聚合函数：SUM | MAX | MIN | AVG | COUNT
- TODO：加入语义聚合
- 只在一个属性上进行聚合

In [None]:
# Build SELECT | FROM | WHERE | AGGREGATION queries (one aggregated attribute each).
with open(valid_where_output_dir, 'r') as f:
    valid_where = json.load(f)

valid_sfwa = []
aggregation_functions = ['COUNT', 'MAX', 'MIN', 'AVG', 'SUM']
# Budget of queries that may aggregate a non-numerical attribute (those use COUNT).
max_COUNT = len(valid_where) // 4

# While the COUNT budget lasts any attribute except DATE-formatted ones may be
# aggregated. The check is done on the attribute name itself; comparing the
# one-element list against formatted_attr never excluded anything.
count_eligible_attrs = [attr for attr in attr_value_dict if attr not in formatted_attr]

for query in tqdm(valid_where):
    if max_COUNT > 0:
        selected_aggregation_column = [random.choice(count_eligible_attrs)]
    else:
        # Budget used up: only numerical attributes remain eligible.
        selected_aggregation_column = random.sample(list(non_formatted_attr), 1)

    numerical = selected_aggregation_column[0] in non_formatted_attr
    if numerical:
        function = random.choice(aggregation_functions[1:])  # anything but COUNT
    else:
        max_COUNT -= 1
        function = 'COUNT'

    # Half of the COUNT queries are written as COUNT(*).
    if function == "COUNT" and random.uniform(0, 1) > 0.5:
        selected_aggregation_column = ["*"]

    query["Type"] = "SFWA"
    query["SELECT"] = [f"{function}({selected_aggregation_column[0]})"]

    # "*" is not an attribute, so it never enters the schema.
    query_attr_list = [col for col in selected_aggregation_column if col != "*"]
    for condition in query["Combination"]:
        if condition[0] not in query_attr_list:
            query_attr_list.append(condition[0])

    query["AGGREGATION"] = selected_aggregation_column
    query["AGGREGATION Function"] = function
    query["Numerical"] = numerical
    query["SCHEMA"] = create_schema(attr_desc_dict, query_attr_list)
    valid_sfwa.append(query)

# random.sample raises ValueError when asked for more items than exist.
sampled_valid_sfwa = random.sample(valid_sfwa, min(sample_sfwa, len(valid_sfwa)))
custom_json_dump(sampled_valid_sfwa, "./SELECT_FROM_WHERE_AGGREGATION.json")

#### 构建 SELECT | FROM | GROUP BY | AGGREGATION

In [None]:
# Build SELECT | FROM | GROUP BY | AGGREGATION queries over the whole table.
valid_sfag = []
aggregation_functions = ['COUNT', 'MAX', 'MIN', 'AVG', 'SUM']
sample_sfag = 20  # number of queries to attempt
max_COUNT = sample_sfag // 4  # budget for randomly chosen COUNT queries

# Columns eligible for MAX / MIN / AVG / SUM. DATE-formatted attributes are
# left out so that this cell agrees with the other aggregation cells, which
# only treat non-formatted attributes as numerical.
# NOTE(review): assumes numerical_attr can contain formatted (date) attributes
# - confirm against its definition earlier in the notebook.
numerical_columns = [attr for attr in numerical_attr if attr not in formatted_attr]

# Only single-valued categorical attributes are used as grouping keys.
valid_groupby_attrs = [attr for attr in category_attr if attr not in multi_value_attributes]

for _ in tqdm(range(sample_sfag)):
    selected_groupby_columns = random.sample(valid_groupby_attrs, 1)
    # The grouping key itself is never aggregated.
    available_columns = sorted(set(attr_value_dict.keys()) - set(selected_groupby_columns))
    available_numerical_columns = sorted(set(numerical_columns) - set(selected_groupby_columns))

    # COUNT is picked with 25% probability while the budget lasts.
    use_count = max_COUNT > 0 and random.uniform(0, 1) < 0.25
    if use_count:
        max_COUNT -= 1
    elif not available_numerical_columns:
        if max_COUNT <= 0:
            # Nothing numerical to aggregate and no COUNT budget left: skip.
            continue
        # Fall back to COUNT when no numerical column is available.
        use_count = True

    if use_count:
        function = 'COUNT'
        numerical = False
        selected_aggregation_column = random.sample(available_columns, 1)
        # Half of the COUNT queries are written as COUNT(*).
        if random.uniform(0, 1) > 0.5:
            selected_aggregation_column = ["*"]
    else:
        function = random.choice(aggregation_functions[1:])  # anything but COUNT
        numerical = True
        selected_aggregation_column = random.sample(available_numerical_columns, 1)

    query_dict = {
        "Type": "SFAG",
        "GROUP BY": selected_groupby_columns,
        "SELECT": selected_groupby_columns + [f"{function}({selected_aggregation_column[0]})"],
        "WHERE Indices": list(range(len(df))),  # every row
        "WHERE Total Rows": len(df),
        "Combination": [],                      # no filter combination
        "Operators": [],                        # no operators
        "WHERE": "None",                        # no WHERE condition
        "AGGREGATION": selected_aggregation_column,
        "AGGREGATION Function": function,
        "Numerical": numerical
    }

    query_attr_list = selected_aggregation_column + selected_groupby_columns
    query_dict["SCHEMA"] = create_schema(attr_desc_dict, query_attr_list)

    valid_sfag.append(query_dict)

custom_json_dump(valid_sfag, "./SELECT_FROM_AGGREGATION_GROUPBY.json")

#### 构建 SELECT | FROM | WHERE | GROUP BY | AGGREGATION

In [None]:
def convert_column_to_float(df, column_name):
    """Convert one column of *df* to numeric values in place.

    Values that cannot be parsed as numbers become NaN.

    Args:
        df: DataFrame to modify.
        column_name: Name of the column to convert.

    Returns:
        The same DataFrame object, for convenience.
    """
    # One vectorized call instead of invoking pd.to_numeric once per cell.
    df[column_name] = pd.to_numeric(df[column_name], errors='coerce')
    return df

# Coerce every non-formatted numerical attribute column of df to numeric values.
for numeric_attr in non_formatted_attr:
    convert_column_to_float(df, numeric_attr)

In [None]:
# Build SELECT | FROM | WHERE | GROUP BY | AGGREGATION queries.
with open(valid_where_output_dir, 'r') as f:
    valid_where = json.load(f)

valid_sfwga = []
aggregation_functions = ['COUNT', 'MAX', 'MIN', 'AVG', 'SUM']
# Budget of queries that may aggregate a non-numerical attribute (those use COUNT).
max_COUNT = sample_sfwga // 4

# Only single-valued categorical attributes are used as grouping keys.
valid_groupby_attrs = [attr for attr in category_attr if attr not in multi_value_attributes]

for query in tqdm(valid_where):
    # Ignore indices that fall outside the currently loaded table.
    row_indices = [idx for idx in query.get("WHERE Indices", []) if idx < len(df)]
    if not row_indices:
        continue
    filtered_df = df.loc[row_indices]

    selected_groupby_columns = random.sample(valid_groupby_attrs, 1)
    grouped = filtered_df.groupby(selected_groupby_columns)
    remaining_rows = grouped.size().reset_index(name='Group Size')
    # Keep only predicates that leave enough distinct groups.
    if len(remaining_rows) <= (min_rows // 2):
        continue

    if max_COUNT > 0:
        # Any attribute except the grouping key and DATE-formatted ones. The
        # check is per attribute name; comparing the one-element list against
        # formatted_attr never excluded anything.
        candidates = [attr for attr in attr_value_dict
                      if attr not in formatted_attr and attr not in selected_groupby_columns]
    else:
        # COUNT budget used up: only numerical attributes remain eligible, as
        # in the SELECT | FROM | WHERE | AGGREGATION cell.
        candidates = [attr for attr in non_formatted_attr
                      if attr not in selected_groupby_columns]
    if not candidates:
        continue
    selected_aggregation_column = [random.choice(candidates)]

    numerical = selected_aggregation_column[0] in non_formatted_attr
    if numerical:
        function = random.choice(aggregation_functions[1:])  # anything but COUNT
    else:
        max_COUNT -= 1
        function = 'COUNT'

    # Half of the COUNT queries are written as COUNT(*).
    if function == "COUNT" and random.uniform(0, 1) > 0.5:
        selected_aggregation_column = ["*"]

    query["Type"] = "SFWGA"
    query["GROUP BY"] = selected_groupby_columns
    query["GROUP BY Total Rows"] = len(remaining_rows)
    query["SELECT"] = selected_groupby_columns + [f"{function}({selected_aggregation_column[0]})"]
    query["AGGREGATION"] = selected_aggregation_column
    query["AGGREGATION Function"] = function
    query["Numerical"] = numerical

    # SCHEMA covers the aggregated attribute, the filters and the grouping key.
    # "*" is not an attribute, so it never enters the schema.
    query_attr_list = [col for col in selected_aggregation_column if col != "*"]
    filter_attrs = [condition[0] for condition in query["Combination"]]
    for attr in filter_attrs + selected_groupby_columns:
        if attr not in query_attr_list:
            query_attr_list.append(attr)
    query["SCHEMA"] = create_schema(attr_desc_dict, query_attr_list)
    valid_sfwga.append(query)

# random.sample raises ValueError when asked for more items than exist.
sampled_valid_sfwga = random.sample(valid_sfwga, min(sample_sfwga, len(valid_sfwga)))
custom_json_dump(sampled_valid_sfwga, "./SELECT_FROM_WHERE_GROUPBY_AGGREGATION.json")

#### 构建 SELECT | FROM | WHERE | TOP-K

In [None]:
# Build SELECT | FROM | WHERE | TOP-K (ORDER BY ... LIMIT k) queries.
with open(valid_where_output_dir, 'r') as f:
    valid_where = json.load(f)

valid_sfwt = []
order_options = ['ASC', 'DESC']

for query in tqdm(valid_where):
    # Ignore indices that fall outside the currently loaded table.
    row_indices = [idx for idx in query.get("WHERE Indices", []) if idx < len(df)]
    if not row_indices:
        continue
    filtered_df = df.loc[row_indices]

    # Re-draw the projection until it holds at least one non-formatted
    # numerical attribute, which is then used as the sort key.
    while True:
        selected_columns = random.sample(list(attr_value_dict.keys()), random.randint(1, max_select))
        sortable_columns = sorted(set(selected_columns) & set(non_formatted_attr))
        if sortable_columns:
            order_column = random.choice(sortable_columns)
            break

    order_type = random.choice(order_options)

    # At most half of the matching rows, but never LIMIT 0 (which a single
    # matching row would otherwise produce).
    limit_value = max(1, min(random.choice(list(limit_list)), len(filtered_df) // 2))

    query["Type"] = "SFWT"
    query["SELECT"] = selected_columns
    query["LIMIT"] = limit_value
    query["ORDER BY"] = [order_column, order_type]

    # SCHEMA covers the projected columns plus every filtered attribute.
    query_attr_list = list(selected_columns)
    for condition in query["Combination"]:
        if condition[0] not in query_attr_list:
            query_attr_list.append(condition[0])
    query["SCHEMA"] = create_schema(attr_desc_dict, query_attr_list)

    valid_sfwt.append(query)

# random.sample raises ValueError when asked for more items than exist.
sampled_valid_sfwt = random.sample(valid_sfwt, min(sample_sfwt, len(valid_sfwt)))
custom_json_dump(sampled_valid_sfwt, "./SELECT_FROM_WHERE_TOPK.json")

#### 构建 SELECT | FROM | WHERE | GROUP BY | AGGREGATION | TOP-K

In [None]:
# Build SELECT | FROM | WHERE | GROUP BY | AGGREGATION | TOP-K queries.
with open(valid_where_output_dir, 'r') as f:
    valid_where = json.load(f)

valid_sfwgat = []
aggregation_functions = ['COUNT', 'MAX', 'MIN', 'AVG', 'SUM']
order_options = ['ASC', 'DESC']
# Budget of queries that may aggregate a non-numerical attribute (those use COUNT).
max_COUNT = sample_sfwgat // 4

# Only single-valued categorical attributes are used as grouping keys.
valid_groupby_attrs = [attr for attr in category_attr if attr not in multi_value_attributes]

for query in tqdm(valid_where):
    # Ignore indices that fall outside the currently loaded table.
    row_indices = [idx for idx in query.get("WHERE Indices", []) if idx < len(df)]
    if not row_indices:
        continue
    filtered_df = df.loc[row_indices]

    selected_groupby_columns = random.sample(valid_groupby_attrs, 1)
    grouped = filtered_df.groupby(selected_groupby_columns)
    remaining_rows = grouped.size().reset_index(name='Group Size')
    # Keep only predicates that leave enough distinct groups.
    if len(remaining_rows) <= (min_rows // 2):
        continue

    if max_COUNT > 0:
        # Any attribute except the grouping key and DATE-formatted ones. The
        # check is per attribute name; comparing the one-element list against
        # formatted_attr never excluded anything.
        candidates = [attr for attr in attr_value_dict
                      if attr not in formatted_attr and attr not in selected_groupby_columns]
    else:
        # COUNT budget used up: only numerical attributes remain eligible.
        candidates = [attr for attr in non_formatted_attr
                      if attr not in selected_groupby_columns]
    if not candidates:
        continue
    selected_aggregation_column = [random.choice(candidates)]

    numerical = selected_aggregation_column[0] in non_formatted_attr
    if numerical:
        function = random.choice(aggregation_functions[1:])  # anything but COUNT
    else:
        max_COUNT -= 1
        function = 'COUNT'

    # Half of the COUNT queries are written as COUNT(*).
    if function == "COUNT" and random.uniform(0, 1) > 0.5:
        selected_aggregation_column = ["*"]

    # The sort key has to be a non-formatted numerical attribute of the query.
    select_set = set(selected_groupby_columns).union(selected_aggregation_column)
    order_set = sorted(select_set & set(non_formatted_attr))
    if not order_set:
        continue
    order_column = random.choice(order_set)
    order_type = random.choice(order_options)

    # The grouped result has one row per group, so LIMIT is bounded by half
    # the number of groups (not by the number of filtered rows) and is never 0.
    limit_value = max(1, min(random.choice(list(limit_list)), len(remaining_rows) // 2))

    query["Type"] = "SFWGAT"
    query["GROUP BY"] = selected_groupby_columns
    query["GROUP BY Total Rows"] = len(remaining_rows)
    query["SELECT"] = selected_groupby_columns + [f"{function}({selected_aggregation_column[0]})"]
    query["AGGREGATION"] = selected_aggregation_column
    query["AGGREGATION Function"] = function
    query["Numerical"] = numerical
    query["LIMIT"] = limit_value
    query["ORDER BY"] = [order_column, order_type]

    # SCHEMA covers the aggregated attribute, the filters and the grouping key.
    # "*" is not an attribute, so it never enters the schema.
    query_attr_list = [col for col in selected_aggregation_column if col != "*"]
    filter_attrs = [condition[0] for condition in query["Combination"]]
    for attr in filter_attrs + selected_groupby_columns:
        if attr not in query_attr_list:
            query_attr_list.append(attr)
    query["SCHEMA"] = create_schema(attr_desc_dict, query_attr_list)

    valid_sfwgat.append(query)

# random.sample raises ValueError when asked for more items than exist.
sampled_valid_sfwgat = random.sample(valid_sfwgat, min(sample_sfwgat, len(valid_sfwgat)))
custom_json_dump(sampled_valid_sfwgat, "./SELECT_FROM_WHERE_GROUPBY_AGGREGATION_TOPK.json")

In [None]:
import json

#### 根据构造的查询结构生成对应的SQL语句
def generate_sql_query(query_dict, table_name="Wikiart"):
    """Render a query dict as a SQL statement.

    Args:
        query_dict: Query description. "SELECT" (list of expressions) is
            required. Optional keys: "WHERE" (expression string, or the
            literal string "None" for no filter), "GROUP BY" (list of
            columns), "ORDER BY" ([column, direction]) and "LIMIT".
        table_name: Table named in the FROM clause.

    Returns:
        The statement with one clause per line, terminated by ";".
    """
    clauses = [
        "SELECT " + ", ".join(query_dict["SELECT"]),
        f"FROM {table_name}",
    ]

    # The literal string "None" marks a query without a filter.
    if query_dict.get("WHERE", "None") != "None":
        clauses.append(f"WHERE {query_dict['WHERE']}")

    if "GROUP BY" in query_dict:
        clauses.append("GROUP BY " + ", ".join(query_dict["GROUP BY"]))

    if "ORDER BY" in query_dict:
        sort_column, sort_direction = query_dict["ORDER BY"]
        clauses.append(f"ORDER BY {sort_column} {sort_direction}")

    if "LIMIT" in query_dict:
        clauses.append(f"LIMIT {query_dict['LIMIT']}")

    return "\n".join(clauses) + ";"

def generate_schema_sql(schema_dict, table_name="Wikiart"):
    """Render a SCHEMA dict as a CREATE TABLE statement.

    Args:
        schema_dict: Mapping of column name -> (data_type, description).
        table_name: Name of the table to create.

    Returns:
        The CREATE TABLE statement, one column per line. A column with a
        non-empty description gets a ``COMMENT '...'`` suffix.
    """
    column_defs = []
    for col_name, (data_type, description) in schema_dict.items():
        definition = f"    {col_name} {data_type}"
        if description:
            # Double embedded single quotes so that the description cannot
            # terminate the SQL string literal early.
            escaped = str(description).replace("'", "''")
            definition += f" COMMENT '{escaped}'"
        column_defs.append(definition)

    return f"CREATE TABLE {table_name} (\n" + ",\n".join(column_defs) + "\n);"

def save_queries_with_sql(input_files, output_dir="./sql_queries/"):
    """Attach SQL text to every generated query and write the results.

    For each existing input file two files are written into *output_dir*:
    ``<stem>_with_sql.json`` (the queries with an added "SQL" entry) and
    ``<stem>.sql`` (the annotated SQL text only). Missing input files are
    reported and skipped.

    Args:
        input_files: Paths of JSON files that each hold a list of query dicts
            with at least "SCHEMA", "SELECT" and "Type".
        output_dir: Directory for the generated files; created if needed.
    """
    import os

    if output_dir:
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(output_dir, exist_ok=True)

    for file_path in input_files:
        if not os.path.exists(file_path):
            print(f"文件不存在: {file_path}")
            continue

        with open(file_path, 'r', encoding='utf-8') as f:
            queries = json.load(f)

        for number, query in enumerate(queries, start=1):
            schema_sql = generate_schema_sql(query["SCHEMA"])
            query_sql = generate_sql_query(query)

            # Header counts: projected expressions and filter conditions.
            select_count = len(query.get("SELECT", []))
            filter_count = len(query.get("Combination", []))

            complete_sql = (
                f"-- Query {number} ({query['Type']})\n"
                f"-- Total Rows: {query.get('WHERE Total Rows', 'N/A')}\n"
                f"-- SELECT: {select_count}\n"
                f"-- FILTER: {filter_count}\n\n"
                + schema_sql + "\n\n"
                + query_sql + "\n"
                + "-" * 50 + "\n\n"
            )

            query["SQL"] = {
                "schema": schema_sql,
                "query": query_sql,
                "complete": complete_sql
            }

        # Derive both output names from the file stem and join them with the
        # directory, so a directory without a trailing separator and an input
        # without a ".json" suffix still give two distinct, correctly placed files.
        stem = os.path.splitext(os.path.basename(file_path))[0]

        output_file = os.path.join(output_dir, stem + '_with_sql.json')
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(queries, f, ensure_ascii=False, indent=2)

        sql_file = os.path.join(output_dir, stem + '.sql')
        with open(sql_file, 'w', encoding='utf-8') as f:
            f.writelines(query["SQL"]["complete"] for query in queries)

        print(f"已生成SQL文件: {sql_file}")

# 使用示例：为所有查询类型生成SQL
# Output files of the query-construction cells above, one per query type.
query_files = [
    f"./{stem}.json"
    for stem in (
        "SELECT_FROM",
        "SELECT_FROM_WHERE",
        "SELECT_FROM_WHERE_TOPK",
        "SELECT_FROM_WHERE_GROUPBY",
        "SELECT_FROM_WHERE_AGGREGATION",
        "SELECT_FROM_AGGREGATION_GROUPBY",
        "SELECT_FROM_WHERE_GROUPBY_AGGREGATION",
        "SELECT_FROM_WHERE_GROUPBY_AGGREGATION_TOPK",
    )
]

# Generate the SQL text for every query file.
save_queries_with_sql(query_files)

#### 固定Filter

In [None]:
#### SQL文件生成器 - 图文结合的Filter模板版本 (修复换行符)
import json
import random
import os
from tqdm import tqdm

# Semantic filter templates that combine image and text attributes.
# Each template names the attributes to filter on ("filters") and a list of
# "semantic_columns"; "description" and "use_case" are free-text labels.
# NOTE(review): several templates reference "Theme", which is not a key of
# attr_desc_dict as defined at the top of this file - confirm it is intended.
TEMPLATES = [
    # 1 filter (2 templates) - may be a purely text or purely image attribute
    {
        "name": "nationality_focus", 
        "description": "特定国籍艺术家分析", 
        "filters": ["Nationality"],
        "use_case": "研究某个国家的艺术家特征和创作倾向",
        "semantic_columns": ["Nationality", "Name", "Age", "Field"]
    },
    {
        "name": "style_focus", 
        "description": "特定艺术风格研究", 
        "filters": ["Style"],
        "use_case": "分析特定艺术风格的视觉特征和表现手法",
        "semantic_columns": [ "Name",  "Theme"]
    },
    
    # 2 filters (3 templates) - must include at least one image attribute and one text attribute
    {
        "name": "nationality_style", 
        "description": "国籍与艺术风格关联分析", 
        "filters": ["Nationality", "Style"],
        "use_case": "探索不同国家艺术家偏好的艺术风格，分析地域文化对艺术表现的影响",
        "semantic_columns": ["Nationality",  "Name", "Birth_continent"]
    },
    {
        "name": "century_color", 
        "description": "历史时期与色彩运用研究", 
        "filters": ["Century", "Color"],
        "use_case": "分析不同历史时期艺术作品的色彩特征和演变趋势",
        "semantic_columns": ["Century",  "Name", "Age", "Style"]
    },
    {
        "name": "age_theme", 
        "description": "艺术家年龄与作品主题关系", 
        "filters": ["Age", "Theme"],
        "use_case": "研究艺术家人生阶段对创作主题选择的影响",
        "semantic_columns": ["Age",  "Name", "Century"]
    },
    
    # 3 filters (3 templates) - must include at least one image attribute and one text attribute
    {
        "name": "nationality_century_style", 
        "description": "国家-时代-风格综合分析", 
        "filters": ["Nationality", "Century", "Style"],
        "use_case": "深度分析特定时期特定国家的主流艺术风格和特征",
        "semantic_columns": ["Nationality", "Century",  "Name", "Birth_continent"]
    },
    {
        "name": "marriage_age_tone", 
        "description": "婚姻状态与艺术表达情感分析", 
        "filters": ["Marriage", "Age", "Tone"],
        "use_case": "探索艺术家的人生状态（婚姻、年龄）对作品情感色调的影响",
        "semantic_columns": ["Marriage", "Age",  "Name", "Theme"]
    },
    # NOTE(review): this template lists only two filters, both text attributes,
    # although it sits in the 3-filter group and its name mentions composition;
    # a third filter (possibly "Composition") looks to be missing - confirm.
    {
        "name": "continent_field_composition", 
        "description": "地域文化与艺术构图研究", 
        "filters": ["Birth_continent", "Field", ],
        "use_case": "分析不同大洲的文化背景对艺术领域和构图方式的影响",
        "semantic_columns": ["Birth_continent", "Field",  "Name", "Nationality"]
    },
    
    # 4 filters (1 template) - must include at least one image attribute and one text attribute
    {
        "name": "european_painting_masters", 
        "description": "欧洲绘画大师深度研究", 
        "filters": ["Birth_continent", "Field", "Age", "Style"],
        "use_case": "专门研究欧洲绘画领域的成熟艺术家，分析其风格特征和创作规律",
        "semantic_columns": ["Birth_continent", "Field", "Age", "Style", "Name", "Awards"]
    },
    
    # 5 filters (1 template) - must include at least one image attribute and one text attribute
    {
        "name": "elite_artist_profile", 
        "description": "精英艺术家全方位画像", 
        "filters": ["Nationality", "Birth_continent", "Age", "Style", "Color"],
        "use_case": "构建顶级艺术家的完整画像，包括地理背景、人生阶段、艺术风格和色彩偏好",
        "semantic_columns": ["Nationality", "Birth_continent", "Age", "Style",  "Name", "Awards"]
    }
]

# Query type codes; they match the "Type" values written by the cells above.
QUERY_TYPES = ["SFW", "SFWT", "SFWG", "SFWA", "SFAG", "SFWGA", "SFWGAT"]

class SQLGenerator:
    def __init__(self, min_result_rows=5):
        self.df = globals()['df']
        self.attr_dict = globals()['attr_value_dict']
        self.numerical_attr = globals()['numerical_attr']
        self.non_numerical_attr = globals()['non_numerical_attr']
        self.category_attr = globals()['category_attr']
        self.formatted_attr = globals()['formatted_attr']
        self.attr_desc_dict = globals()['attr_desc_dict']
        self.min_rows = min_result_rows
        
        # 属性分类
        self.pure_text_attributes = {"Nationality", "Birth_date", "Death_date", "Age", "Century", 
                                   "Zodiac", "Birth_country", "Birth_city", "Birth_continent", 
                                   "Death_country", "Death_city", "Field", "Marriage", 
                                   "Art_institution", "Teaching", "Awards"}
        
        self.pure_image_attributes = {"Style", "Image_genre", "Color", "Tone", "Composition"}
        
        self.common_attributes = {"Name"}
        
        # 分析数据分布用于放宽策略
        self.stats = self._analyze_data()
        
    def _analyze_data(self):
        """分析数据分布"""
        stats = {}
        for attr in self.attr_dict.keys():
            if attr in self.non_numerical_attr:
                try:
                    value_counts = self.df[attr].value_counts()
                    stats[attr] = {
                        "type": "categorical", 
                        "top_values": value_counts.head(5).to_dict()
                    }
                except:
                    stats[attr] = {"type": "categorical", "top_values": {}}
        return stats
    
    def get_filter_value(self, attr, relaxation_level=0):
        """从attr_value_dict获取值和条件，支持放宽策略"""
        values = self.attr_dict.get(attr, [])
        if not values:
            return None, None
        
        if attr in self.numerical_attr:
            val = random.choice(values)
            
            if relaxation_level == 0:
                op = random.choice(["==", ">", ">=", "<", "<="])
            elif relaxation_level == 1:
                op = random.choice([">=", "<=", ">", "<", ">=", "<="])
            elif relaxation_level == 2:
                op = random.choice([">=", "<="])
            else:
                op = random.choice([">=", "<="])
                
            if relaxation_level >= 1:
                if attr == "Age":
                    sorted_vals = sorted(values)
                    if op in [">=", ">"]:
                        val = random.choice(sorted_vals[:len(sorted_vals)//3])
                    elif op in ["<=", "<"]:
                        val = random.choice(sorted_vals[2*len(sorted_vals)//3:])
                elif attr == "Awards":
                    if op in [">=", ">"]:
                        min_awards = min([v for v in values if isinstance(v, (int, float))])
                        val = min_awards if relaxation_level >= 2 else random.choice([v for v in values if v <= min_awards + 1])
                elif attr == "Person_count":
                    if op in [">=", ">"]:
                        val = random.choice([v for v in values if v <= 2])
                    elif op in ["<=", "<"]:
                        val = random.choice([v for v in values if v >= 3])
            
            return val, f"{attr} {op} {val}"
        else:
            if relaxation_level >= 1 and attr in self.stats and self.stats[attr]["type"] == "categorical":
                top_values = list(self.stats[attr]["top_values"].keys())[:3]
                common_values = [v for v in top_values if v in values]
                if common_values:
                    val = random.choice(common_values)
                else:
                    val = random.choice(values)
            else:
                val = random.choice(values)
            return val, f"{attr} == '{val}'"
    
    def apply_filters_with_relaxation(self, filters_config, max_relaxation=3):
        """应用过滤条件，支持自动放宽策略"""
        for relaxation_level in range(max_relaxation + 1):
            result_indices, applied_conditions = self._try_apply_filters(filters_config, relaxation_level)
            if len(result_indices) >= 2:
                return result_indices, applied_conditions, relaxation_level
        return None, None, "FAILED_ALL_RELAXATION"
    
    def _try_apply_filters(self, filters_config, relaxation_level):
        """尝试应用过滤条件，严格保持Filter数量"""
        result_indices = set(range(len(self.df)))
        applied_conditions = []
        
        for attr in filters_config:
            val, condition = self.get_filter_value(attr, relaxation_level)
            if not val or not condition:
                return [], []
                
            try:
                if condition.endswith("'"):
                    mask = self.df[attr].apply(lambda x: 
                        str(x).strip().lower() == str(val).strip().lower() or
                        (isinstance(x, str) and '||' in x and str(val) in x.split('||')))
                elif ">=" in condition:
                    mask = self.df[attr] >= val
                elif ">" in condition:
                    mask = self.df[attr] > val
                elif "<=" in condition:
                    mask = self.df[attr] <= val
                elif "<" in condition:
                    mask = self.df[attr] < val
                elif "==" in condition:
                    mask = self.df[attr] == val
                else:
                    return [], []
                
                new_indices = set(self.df[mask].index)
                result_indices &= new_indices
                applied_conditions.append(condition)
                
            except Exception:
                return [], []
        
        if len(applied_conditions) != len(filters_config):
            return [], []
        
        return list(result_indices), applied_conditions
    
    def create_schema_sql(self, attrs, table_name="Wikiart"):
        """生成建表SQL"""
        schema_parts = []
        for attr in set(attrs):
            if '(' in attr:
                continue
            if attr in self.numerical_attr:
                schema_parts.append(f"    {attr} FLOAT")
            elif attr in self.formatted_attr:
                schema_parts.append(f"    {attr} DATE")
            else:
                schema_parts.append(f"    {attr} VARCHAR(255)")
        
        return f"CREATE TABLE {table_name} (\n" + ",\n".join(schema_parts) + "\n);"
    
    def validate_template_requirements(self, template):
        """验证模板是否符合图文结合要求"""
        filters = set(template["filters"])
        filter_count = len(filters)
        
        if filter_count >= 2:
            has_image_attr = bool(filters & self.pure_image_attributes)
            has_text_attr = bool(filters & self.pure_text_attributes)
            return has_image_attr and has_text_attr
        
        return True
    
    def get_semantic_columns(self, template, qtype):
        """Return the leading columns to project for a template and query type.

        The columns come from ``template["semantic_columns"]``, falling back to
        ``template["filters"]``. The ordering/grouping/aggregating query types
        get the first two columns; "SFW" and any other type get the first three.

        Returns:
            list: a new list holding the selected column names.
        """
        fallback = template["filters"]
        columns = template.get("semantic_columns", fallback)

        two_column_types = ("SFWT", "SFWG", "SFWA", "SFAG", "SFWGA", "SFWGAT")
        width = 2 if qtype in two_column_types else 3
        return columns[:width]
    
    def get_semantic_group_column(self, template):
        """为模板选择语义相关的分组列"""
        filters = set(template["filters"])
        
        relevant_categories = [attr for attr in self.category_attr if attr in filters]
        if relevant_categories:
            return random.choice(relevant_categories)
        
        semantic_groups = {
            "nationality_focus": ["Nationality", "Birth_continent"],
            "style_focus": ["Style", "Image_genre"],
            "nationality_style": ["Nationality", "Style"],
            "century_color": ["Century", "Style"],
            "age_theme": ["Century", "Marriage"],
            "nationality_century_style": ["Nationality", "Century"],
            "marriage_age_tone": ["Marriage", "Century"],
            "continent_field_composition": ["Birth_continent", "Field"],
            "european_painting_masters": ["Field", "Century"],
            "elite_artist_profile": ["Nationality", "Birth_continent"]
        }
        
        preferred = semantic_groups.get(template["name"], self.category_attr)
        available = [col for col in preferred if col in self.category_attr]
        
        return random.choice(available) if available else random.choice(self.category_attr)
    
    def get_semantic_aggregation_column(self, template):
        """为模板选择语义相关的聚合列"""
        filters = set(template["filters"])
        
        relevant_numerics = [attr for attr in self.numerical_attr if attr in filters]
        if relevant_numerics:
            return random.choice(relevant_numerics)
        
        semantic_aggs = {
            "nationality_focus": ["Age", "Awards"],
            "style_focus": ["Style", "Image_genre"],
            "nationality_style": ["Age", "Awards"],
            "century_color": ["Age", "Person_count"],
            "age_theme": ["Age", "Person_count"],
            "nationality_century_style": ["Age", "Awards"],
            "marriage_age_tone": ["Age", "Person_count"],
            "continent_field_composition": ["Age", "Awards"],
            "european_painting_masters": ["Age", "Awards"],
            "elite_artist_profile": ["Age", "Awards"]
        }
        
        preferred = semantic_aggs.get(template["name"], self.numerical_attr)
        available = [col for col in preferred if col in self.numerical_attr]
        
        return random.choice(available) if available else random.choice(self.numerical_attr)
    
    def generate_query_sql(self, template, qtype, query_id):
        """生成单个查询的SQL，带语义化改进"""
        
        if not self.validate_template_requirements(template):
            return None
        # SFAG 不需要WHERE条件，直接使用所有行
        if qtype == "SFAG":
            indices = list(range(len(self.df)))
            conditions = []
            relaxation_used = 0
        else:
            result = self.apply_filters_with_relaxation(template["filters"])
            
            if result[2] == "FAILED_ALL_RELAXATION":
                return None
            
            indices, conditions, relaxation_used = result
            
            if len(indices) < 2 or not conditions:
                return None
            
        result = self.apply_filters_with_relaxation(template["filters"])
        
        if result[2] == "FAILED_ALL_RELAXATION":
            return None
        
        indices, conditions, relaxation_used = result
        
        if len(indices) < 2 or not conditions:
            return None
        
        select_cols = self.get_semantic_columns(template, qtype)
        schema_attrs = select_cols.copy()
        
        sql_parts = []
        
        if qtype == "SFW":
            sql_parts.append(f"SELECT {', '.join(select_cols)}")
            
        elif qtype == "SFWT":
            numeric_cols = [c for c in select_cols if c in self.numerical_attr]
            if not numeric_cols:
                numeric_col = self.get_semantic_aggregation_column(template)
                select_cols.append(numeric_col)
                schema_attrs.append(numeric_col)
                numeric_cols = [numeric_col]
            
            order_col = random.choice(numeric_cols)
            order_dir = random.choice(["ASC", "DESC"])
            limit_val = random.choice([5, 10, 15, 20])
            
            sql_parts.extend([
                f"SELECT {', '.join(select_cols)}",
                "FROM Wikiart",
                f"WHERE {' AND '.join(conditions)}",
                f"ORDER BY {order_col} {order_dir}",
                f"LIMIT {limit_val}"
            ])
            
        elif qtype == "SFWG":
            group_col = self.get_semantic_group_column(template)
            if group_col not in select_cols:
                select_cols.append(group_col)
                schema_attrs.append(group_col)
            
            sql_parts.extend([
                f"SELECT {', '.join(select_cols)}",
                "FROM Wikiart",
                f"WHERE {' AND '.join(conditions)}",
                f"GROUP BY {group_col}"
            ])
            
        elif qtype == "SFWA":
            func = random.choice(["COUNT", "MAX", "MIN", "AVG", "SUM"])
            if func == "COUNT" and random.random() < 0.3:
                select_cols = [f"{func}(*)"]
                schema_attrs = template["filters"]
            else:
                agg_col = self.get_semantic_aggregation_column(template)
                select_cols = [f"{func}({agg_col})"]
                schema_attrs = [agg_col] + template["filters"]
            
            sql_parts.append(f"SELECT {', '.join(select_cols)}")
        
        elif qtype == "SFAG":
            group_col = self.get_semantic_group_column(template)
            func = random.choice(["COUNT", "MAX", "MIN", "AVG", "SUM"])
            
            if func == "COUNT" and random.random() < 0.3:
                select_cols = [group_col, f"{func}(*)"]
                schema_attrs = [group_col] + template["filters"]
            else:
                agg_col = self.get_semantic_aggregation_column(template)
                select_cols = [group_col, f"{func}({agg_col})"]
                schema_attrs = [group_col, agg_col] + template["filters"]
            
            sql_parts.extend([
                f"SELECT {', '.join(select_cols)}",
                "FROM Wikiart",
                f"GROUP BY {group_col}"
            ])
        
        elif qtype == "SFWGA":
            group_col = self.get_semantic_group_column(template)
            func = random.choice(["COUNT", "MAX", "MIN", "AVG", "SUM"])
            
            if func == "COUNT" and random.random() < 0.3:
                select_cols = [group_col, f"{func}(*)"]
                schema_attrs = [group_col] + template["filters"]
            else:
                agg_col = self.get_semantic_aggregation_column(template)
                select_cols = [group_col, f"{func}({agg_col})"]
                schema_attrs = [group_col, agg_col] + template["filters"]
            
            sql_parts.extend([
                f"SELECT {', '.join(select_cols)}",
                "FROM Wikiart",
                f"WHERE {' AND '.join(conditions)}",
                f"GROUP BY {group_col}"
            ])
            
        elif qtype == "SFWGAT":
            group_col = self.get_semantic_group_column(template)
            func = random.choice(["MAX", "MIN", "AVG", "SUM"])
            agg_col = self.get_semantic_aggregation_column(template)
            order_dir = random.choice(["ASC", "DESC"])
            limit_val = random.choice([5, 10, 15])
            
            select_cols = [group_col, f"{func}({agg_col})"]
            schema_attrs = [group_col, agg_col] + template["filters"]
            
            sql_parts.extend([
                f"SELECT {', '.join(select_cols)}",
                "FROM Wikiart",
                f"WHERE {' AND '.join(conditions)}",
                f"GROUP BY {group_col}",
                f"ORDER BY {func}({agg_col}) {order_dir}",
                f"LIMIT {limit_val}"
            ])
        
        # 添加基本的FROM和WHERE（如果还没有）
        if len(sql_parts) == 1:  # 只有SELECT
            sql_parts.extend([
                "FROM Wikiart",
                f"WHERE {' AND '.join(conditions)}"
            ])
        
        # 确保schema包含所有用到的属性
        for condition in conditions:
            attr_name = condition.split()[0]
            if attr_name not in schema_attrs:
                schema_attrs.append(attr_name)
        
        # 生成完整SQL
        schema_sql = self.create_schema_sql(schema_attrs)
        query_sql = "\n".join(sql_parts) + ";"
        
        # 分析模板的图文属性组合
        template_filters = set(template["filters"])
        image_filters = template_filters & self.pure_image_attributes
        text_filters = template_filters & self.pure_text_attributes
        common_filters = template_filters & self.common_attributes
        
        header = f"-- Query {query_id} - {qtype}\n"
        header += f"-- Template: {template['name']}\n"
        header += f"-- Description: {template['description']}\n"
        header += f"-- Use Case: {template['use_case']}\n"
        header += f"-- Result Rows: {len(indices)}\n"
        header += f"-- Filter Composition:\n"
        header += f"--   ├─ Image Attributes: {', '.join(image_filters) if image_filters else 'None'}\n"
        header += f"--   ├─ Text Attributes: {', '.join(text_filters) if text_filters else 'None'}\n"
        header += f"--   └─ Common Attributes: {', '.join(common_filters) if common_filters else 'None'}\n"
        header += f"-- Filters Applied: {len(conditions)}/{len(template['filters'])} (EXACT MATCH REQUIRED)"
        if relaxation_used > 0:
            header += f" (Values relaxed {relaxation_used} times)"
        header += "\n\n"
        
        complete_sql = header + schema_sql + "\n\n" + query_sql + "\n\n" + "-" * 60 + "\n\n"
        
        return complete_sql

def validate_all_templates():
    """Check every entry of TEMPLATES against the image/text mixing rule.

    A template with two or more distinct filters must use at least one
    image-only and one text-only attribute; templates with fewer filters
    always pass.

    Returns:
        bool: True when every template passes (also for an empty TEMPLATES).
    """
    text_only = {
        "Nationality", "Birth_date", "Death_date", "Age", "Century",
        "Zodiac", "Birth_country", "Birth_city", "Birth_continent",
        "Death_country", "Death_city", "Field", "Marriage",
        "Art_institution", "Teaching", "Awards",
    }
    image_only = {"Style", "Image_genre", "Color", "Tone", "Composition"}

    def mixes_modalities(template):
        chosen = set(template["filters"])
        if len(chosen) < 2:
            return True
        return bool(chosen & image_only) and bool(chosen & text_only)

    # Evaluate every template (no short-circuit) before combining the verdicts.
    verdicts = [mixes_modalities(template) for template in TEMPLATES]
    return all(verdicts)

def generate_all_sql_files():
    """Generate the SQL files: one folder per template, one file per query type.

    For every template in TEMPLATES and every type in QUERY_TYPES up to five
    queries are generated (at most 25 attempts) and written to
    ``./Image_Text_Combined_Filters/<template name>/<query type>.sql``. A file
    is written even when no query could be generated; it then only holds the
    header and a notice. Nothing is generated when template validation fails.
    """
    # Imported locally so this function does not depend on an `os` import
    # elsewhere in the file.
    import os

    if not validate_all_templates():
        # Say why nothing was produced instead of returning silently.
        print("模板验证失败: 存在不满足图文结合要求的模板，未生成任何查询")
        return

    base_dir = "./Image_Text_Combined_Filters/"
    os.makedirs(base_dir, exist_ok=True)

    generator = SQLGenerator()
    queries_per_type = 5

    total_generated = 0

    for template in tqdm(TEMPLATES, desc="生成SQL"):
        template_dir = os.path.join(base_dir, template["name"])
        os.makedirs(template_dir, exist_ok=True)

        template_count = 0

        for qtype in QUERY_TYPES:
            sql_content = []
            attempts = 0
            # Generation is randomised and may fail, so allow several retries.
            max_attempts = queries_per_type * 5

            while len(sql_content) < queries_per_type and attempts < max_attempts:
                attempts += 1

                # Query ids are consecutive over the successfully generated queries.
                sql = generator.generate_query_sql(template, qtype, len(sql_content) + 1)
                if sql:
                    sql_content.append(sql)

            template_count += len(sql_content)

            filename = os.path.join(template_dir, f"{qtype}.sql")
            with open(filename, 'w', encoding='utf-8') as f:
                f.write(f"-- {template['description']} - {qtype} 查询集合\n")
                f.write(f"-- 模板: {template['name']}\n")
                f.write(f"-- Filter数量: {len(template['filters'])}\n")
                f.write(f"-- 用途: {template['use_case']}\n")
                f.write("-- " + "=" * 60 + "\n\n")

                if sql_content:
                    f.write("".join(sql_content))
                else:
                    f.write("-- 注意: 未能生成有效查询\n")

        total_generated += template_count

    print(f"完成! 共生成 {total_generated} 个查询，保存在: {base_dir}")

# Run the generation.
generate_all_sql_files()