In [None]:
import re
import os
import json
import pickle
import random
import openpyxl
import itertools
import pandas as pd
from datetime import datetime
from tqdm.autonotebook import tqdm
from collections import defaultdict

#### 定义输入输出路径，并加载数据
- `dataset_dirs`：三个医学数据集表格路径
- `statistics_output_dir`：统计数据表输出路径，包括属性值、选择率、基数
- `valid_where_output_dir`：所有有效谓词组合的输出路径
- **注意**：支持 .csv 格式，自动处理编码问题

In [None]:
# Dataset path configuration: logical table name -> CSV file on disk.
dataset_dirs = {
    "disease": r"/data2/liujinqi/Benchmark/Query/Medical_Autoconstruct/disease.csv",
    "drug": r"/data2/liujinqi/Benchmark/Query/Medical_Autoconstruct/drug.csv", 
    "institutes": r"/data2/liujinqi/Benchmark/Query/Medical_Autoconstruct/institutes.csv"
}

# Output locations: the per-attribute statistics table (CSV) and the list of
# valid WHERE predicate combinations (JSON).
statistics_output_dir = r"/data2/liujinqi/Benchmark/Query/Medical_Autoconstruct/medical_statistics.csv"
valid_where_output_dir = r"/data2/liujinqi/Benchmark/Query/Medical_Autoconstruct/valid_WHERE.json"

# Relationships between tables, one 4-tuple per table pair.
# NOTE(review): presumably (left_table, right_table, left_column, right_column)
# — confirm against the code that consumes this list.
table_relationships = [
    ("disease", "drug", "disease_name", "disease_name"),
    ("drug", "institutes", "manufacturer", "institution_name"),
    ("disease", "institutes", "disease_name", "research_diseases")  # multi-value match
]

# Table loading helper
def load_tables():
    """Read every CSV listed in ``dataset_dirs`` into a DataFrame.

    Each file is first decoded as UTF-8. Only when that fails with a
    ``UnicodeDecodeError`` is the read retried as cp1252; any other failure
    (missing file, malformed CSV, ...) propagates to the caller instead of
    being masked by a second read attempt.

    Returns:
        dict mapping table name -> ``pandas.DataFrame``.
    """
    tables = {}
    for table_name, file_path in dataset_dirs.items():
        try:
            frame = pd.read_csv(file_path, encoding='utf-8')
        except UnicodeDecodeError:
            # Fall back to the Windows code page for files that are not UTF-8.
            frame = pd.read_csv(file_path, encoding='cp1252')
        tables[table_name] = frame
        print(f"Loaded {table_name}: {len(tables[table_name])} rows")
    return tables

# Load all tables once; the cells below reuse this dict of DataFrames.
tables = load_tables()


#### 定义属性，方便后面按不同属性类型设计不同的构造方法
- `attr_desc_dict`：全部属性的集合，以及对应的自然语言描述
- `disease_attributes`、`drug_attributes`、`institute_attributes`：各表特有属性
- `non_numerical_attr`：非数值属性的集合
- `numerical_attr`：数值属性的集合（当前 `numerical_attr_list` 为空，三张表暂无数值属性）
- `category_attr`：固定类别的属性
- `multi_value_attributes`：多值的属性，用"||"分隔或者逗号分隔


In [None]:
# Every known attribute mapped to its natural-language description.
# All descriptions are currently empty strings.
attr_desc_dict = {
    # Disease table attributes
    "disease_name": "", "disease_type": "", "pathogenesis": "", "etiology": "",
    "diagnostic_methods": "", "common_symptoms": "", "complications": "", "affected_organs": "",
    "treatments": "", "drugs": "", "prognosis": "", "sequelae": "", "epidemiology": "",
    "risk_factors": "", "preventive_measures": "", "diagnosis_challenges": "",
    "treatment_challenges": "", "quality_of_life_impact": "",
    
    # Drug table attributes
    "generic_name": "", "brand_name": "", "indication": "", "active_ingredients": "",
    "pharmaceutical_form": "", "manufacturer": "", "administration_route": "",
    "recommended_usage": "", "single_dose": "", "dosage_frequency": "",
    "mechanism_of_action": "", "side_effects": "", "activation_conditions": "",
    "prescription_status": "", "unsuitable_population": "", "storage_conditions": "",
    
    # Institute table attributes
    "institution_name": "", "institution_type": "", "parent_organization": "", "leadership": "",
    "institution_country": "", "institution_city": "", "research_diseases": "",
    "research_fields": "", "key_technologies": "", "key_achievements": "",
    "international_collaboration": "", "funding_sources": "", "technology_application": "",
    # "ID" is not listed in any of the per-table attribute lists below.
    "ID": ""
}

# Attributes considered for the "disease" table.
disease_attributes = [
    "disease_name", "disease_type", "pathogenesis", "etiology", "diagnostic_methods",
    "common_symptoms", "complications", "affected_organs", "treatments", "drugs",
    "prognosis", "sequelae", "epidemiology", "risk_factors", "preventive_measures",
    "diagnosis_challenges", "treatment_challenges", "quality_of_life_impact",
]

# Attributes considered for the "drug" table. "disease_name" also appears in
# disease_attributes, so the bare name is shared by two tables.
drug_attributes = [
    "generic_name", "brand_name", "disease_name", "indication", "active_ingredients",
    "pharmaceutical_form", "manufacturer", "administration_route", "recommended_usage",
    "single_dose", "dosage_frequency", "mechanism_of_action", "side_effects",
    "activation_conditions", "prescription_status", "unsuitable_population",
    "storage_conditions",
]

# Attributes considered for the "institutes" table.
institute_attributes = [
    "institution_name", "institution_type", "parent_organization", "leadership", "institution_country", "institution_city",
    "research_diseases", "research_fields", "key_technologies", "key_achievements",
    "international_collaboration", "funding_sources", "technology_application", 
]

# table name -> {attribute: description}; raises KeyError at definition time
# if a listed attribute is missing from attr_desc_dict.
table_attributes = {
    "disease": {attr: attr_desc_dict[attr] for attr in disease_attributes},
    "drug": {attr: attr_desc_dict[attr] for attr in drug_attributes},
    "institutes": {attr: attr_desc_dict[attr] for attr in institute_attributes}
}

# Attributes treated as text. "disease_name" is listed twice (once with the
# disease attributes, once with the drug attributes); the code in this section
# only uses the list for membership tests, so the duplicate has no effect there.
non_numerical_attr_list = [
    "disease_name", "disease_type", "pathogenesis", "etiology", "diagnostic_methods",
    "common_symptoms", "complications", "affected_organs", "treatments", "drugs",
    "prognosis", "sequelae", "epidemiology", "risk_factors", "preventive_measures", 
    "diagnosis_challenges", "treatment_challenges", "quality_of_life_impact",
    "generic_name", "brand_name", "disease_name", "indication", "active_ingredients",
    "pharmaceutical_form", "manufacturer", "administration_route", "recommended_usage",
    "single_dose", "dosage_frequency", "mechanism_of_action", "side_effects",
    "activation_conditions", "prescription_status", "unsuitable_population", 
    "storage_conditions", "institution_name", "institution_type", "parent_organization", 
    "leadership", "institution_country", "institution_city", "research_diseases", 
    "research_fields", "key_technologies", "key_achievements", "international_collaboration",
    "funding_sources", "technology_application",
]

# Attributes treated as numbers. Empty: no attribute is currently numeric, so
# every numeric code path below is inactive.
numerical_attr_list = []

# Attributes whose cells may hold several values separated by "||" or ",".
# The third group repeats drug attributes already present in the first group;
# the list is only used for membership tests in this section.
multi_value_attributes_list = [
    "disease_name", "disease_type", "pathogenesis", "etiology", "indication", "diagnostic_methods", 
    "common_symptoms", "active_ingredients", "generic_name", "brand_name", "pharmaceutical_form",
    "manufacturer", "administration_route", "recommended_usage", "single_dose", "dosage_frequency", 
    "mechanism_of_action", "side_effects", "activation_conditions", "prescription_status", 
    "unsuitable_population", "storage_conditions",
    
    "institution_name", "institution_type", "parent_organization",  
    "leadership", "institution_country", "institution_city", "research_diseases",
    "research_fields", "key_technologies", "key_achievements", "international_collaboration", 
    "funding_sources", "technology_application",

    "generic_name", "brand_name", "disease_name", "indication", "active_ingredients", "pharmaceutical_form",
    "manufacturer", "administration_route", "recommended_usage", "single_dose", "dosage_frequency", 
    "mechanism_of_action", "side_effects", "activation_conditions", "prescription_status", "unsuitable_population",
    "storage_conditions"
]

# Attributes declared as having a fixed set of categories.
# NOTE(review): not referenced by any function visible in this section other
# than the per-table grouping that follows — confirm where it is consumed.
category_attr_list = [
    "disease_type","pathogenesis","diagnostic_methods","treatments","prognosis",
    "risk_factors","preventive_measures","quality_of_life_impact","institution_type","pharmaceutical_form",
    "administration_route","recommended_usage","activation_conditions","prescription_status","storage_conditions"
]

# Group each attribute classification by table, keeping every table's own
# attribute order.
_attrs_by_table = (
    ("disease", disease_attributes),
    ("drug", drug_attributes),
    ("institutes", institute_attributes),
)


def _group_by_table(allowed):
    """Return {table: [its attributes that also occur in ``allowed``]}."""
    return {
        table: [attr for attr in attrs if attr in allowed]
        for table, attrs in _attrs_by_table
    }


non_numerical_attr = _group_by_table(non_numerical_attr_list)

numerical_attr = _group_by_table(numerical_attr_list)

multi_value_attributes = _group_by_table(multi_value_attributes_list)

category_attr = _group_by_table(category_attr_list)

#### 生成统计信息
- 属性 | 属性值 | 选择率 | 基数
- 用于后续构造 Filter
- 输出CSV格式，类似于原始Wikiart统计表

In [None]:


def _explode_cell(cell):
    """Split one multi-value cell into stripped, non-empty tokens.

    "||" takes precedence over ","; a cell with neither separator is returned
    as a single stripped token.
    """
    text = str(cell)
    for separator in ('||', ','):
        if separator in text:
            return [token.strip() for token in text.split(separator) if token.strip()]
    return [text.strip()]


def generate_medical_statistics(tables):
    """Build the attribute / value / count / selectivity table for all tables.

    For every attribute of ``table_attributes`` that exists as a column, three
    output columns are produced: ``"<table>.<attribute>"`` (the distinct values
    followed by a final ``"(null)"`` entry), ``"Count"`` and ``"Selectivity"``.
    The triples of all attributes are placed side by side, so ``"Count"`` and
    ``"Selectivity"`` repeat once per attribute.

    Selectivity is always ``count / total number of rows``, for multi-value
    attributes, ordinary attributes and the null entry alike. Ordinary
    attributes were previously normalised by the non-null row count, which was
    inconsistent with the other two cases.

    Args:
        tables: dict mapping table name -> DataFrame.

    Returns:
        A DataFrame with the statistics; empty when no attribute matched.
    """
    per_column_frames = []

    for table_name, df in tqdm(tables.items(), desc="生成统计信息"):
        total_rows = len(df)

        for column in tqdm(table_attributes[table_name].keys(), desc=f"处理{table_name}属性", leave=False):
            if column not in df.columns:
                continue

            if column in multi_value_attributes_list:
                # Multi-value attribute: count the individual split values only.
                tokens = [
                    token
                    for cell in df[column].dropna()
                    for token in _explode_cell(cell)
                ]
                value_counts = pd.Series(tokens, dtype=object).value_counts()
            else:
                value_counts = df[column].value_counts()

            # value_counts is empty when total_rows == 0, so this never
            # divides an actual count by zero.
            selectivities = (value_counts / total_rows).round(3)

            null_count = int(df[column].isnull().sum())
            null_selectivity = round(null_count / total_rows, 3) if total_rows else 0.0

            per_column_frames.append(pd.DataFrame({
                f"{table_name}.{column}": list(value_counts.index) + ["(null)"],
                'Count': list(value_counts.values) + [null_count],
                'Selectivity': list(selectivities.values) + [null_selectivity]
            }))

    if not per_column_frames:
        return pd.DataFrame()

    # A single concat avoids re-copying the growing frame once per attribute.
    return pd.concat(per_column_frames, axis=1)

# Compute the statistics for all loaded tables and write them to the CSV
# configured in statistics_output_dir (no index column, UTF-8).
medical_statistics = generate_medical_statistics(tables)
medical_statistics.to_csv(statistics_output_dir, index=False, encoding='utf-8')

#### 定义查询构造参数

In [None]:
# NOTE(review): max_filters is not referenced by the code visible in this
# section; build_where_combinations hard-codes its own upper bound of 5.
max_filters = 5
# Minimum number of matching rows for a filter / WHERE clause to be kept.
min_rows = 5
# Upper bound on the number of attributes placed in a SELECT clause.
max_select = 4
# NOTE(review): presumably the candidate LIMIT / TOP-K values — confirm against
# the TOP-K query generators, which are outside this section.
limit_list = [1, 2, 5, 10, 20, 50]

# Number of queries to generate per query type
sample_sf = 20         # SELECT FROM (single table)
sample_sfj = 30        # SELECT FROM JOIN
sample_sfw = 30        # SELECT FROM WHERE (single table)
sample_sfwj = 50       # SELECT FROM WHERE JOIN
sample_sfwt = 20       # SELECT FROM WHERE TOP-K (single table)
sample_sfwtj = 30      # SELECT FROM WHERE TOP-K JOIN
sample_sfwg = 20       # SELECT FROM WHERE GROUP BY (single table)
sample_sfwgj = 30      # SELECT FROM WHERE GROUP BY JOIN
sample_sfwa = 20       # SELECT FROM WHERE AGGREGATION (single table)
sample_sfwaj = 30      # SELECT FROM WHERE AGGREGATION JOIN
sample_sfag = 20       # SELECT FROM AGGREGATION GROUP BY (single table)
sample_sfagj = 30      # SELECT FROM AGGREGATION GROUP BY JOIN
sample_sfwga = 30      # SELECT FROM WHERE GROUP BY AGGREGATION (single table)
sample_sfwgaj = 40     # SELECT FROM WHERE GROUP BY AGGREGATION JOIN
sample_sfwgat = 20     # SELECT FROM WHERE GROUP BY AGGREGATION TOP-K (single table)
sample_sfwgatj = 30    # SELECT FROM WHERE GROUP BY AGGREGATION TOP-K JOIN

#### 定义Filter执行方法

In [None]:
def non_numerical_equal_to(value, condition):
    """Return True when ``value`` and ``condition`` are equal as text.

    Both sides are converted with ``str()``, lower-cased and stripped of
    surrounding whitespace before comparison. A value whose string conversion
    fails compares as not equal.
    """
    try:
        return str(value).lower().strip() == str(condition).lower().strip()
    except Exception:
        # Only a broken __str__ can land here; treat it as "no match" rather
        # than aborting the row-wise apply. KeyboardInterrupt and SystemExit
        # are deliberately no longer swallowed.
        return False

def non_numerical_equal_to_with_split(cell, condition):
    """Return True when ``condition`` equals one of the values stored in ``cell``.

    A cell containing "||" is split on "||"; otherwise a cell containing ","
    is split on ","; otherwise the whole cell is one value. Every comparison
    ignores case and surrounding whitespace, so split and unsplit cells follow
    the same matching rule. Missing cells never match.

    NOTE(review): this name is defined a second time further down in the file;
    the later definition is the one in effect when the filters are built.
    """
    if pd.isna(cell):
        return False
    wanted = str(condition).lower().strip()
    cell_str = str(cell)
    for separator in ('||', ','):
        if separator in cell_str:
            return any(
                token.lower().strip() == wanted
                for token in cell_str.split(separator)
            )
    return cell_str.lower().strip() == wanted

def _float_pair(value, condition):
    """Coerce both operands to float for a numeric comparison.

    A missing ``value`` (``pd.isna``) is treated as 0. Returns
    ``(value, condition)`` as floats, or ``None`` when either operand cannot
    be converted.
    """
    if pd.isna(value):
        value = 0
    try:
        return float(value), float(condition)
    except (TypeError, ValueError, OverflowError):
        return None


def number_greater_than(value, condition):
    """Return True when ``value > condition`` numerically; False if not comparable."""
    pair = _float_pair(value, condition)
    return pair is not None and pair[0] > pair[1]


def number_less_than(value, condition):
    """Return True when ``value < condition`` numerically; False if not comparable."""
    pair = _float_pair(value, condition)
    return pair is not None and pair[0] < pair[1]


def number_equal_to(value, condition):
    """Return True when ``value == condition`` numerically; False if not comparable."""
    pair = _float_pair(value, condition)
    return pair is not None and pair[0] == pair[1]


def number_greater_equal(value, condition):
    """Return True when ``value >= condition`` numerically; False if not comparable."""
    pair = _float_pair(value, condition)
    return pair is not None and pair[0] >= pair[1]


def number_less_equal(value, condition):
    """Return True when ``value <= condition`` numerically; False if not comparable."""
    pair = _float_pair(value, condition)
    return pair is not None and pair[0] <= pair[1]

def non_numerical_equal_to_with_split(cell, condition):
    """Return True when ``condition`` equals one of the values stored in ``cell``.

    A cell containing "||" is split on "||"; otherwise a cell containing ","
    is split on ","; otherwise the whole cell is one value. Every comparison
    ignores case and surrounding whitespace. Previously only unsplit cells
    were compared case-insensitively, so "Aspirin" matched "aspirin" but
    "Aspirin||Ibuprofen" did not. Missing cells never match.
    """
    if pd.isna(cell):
        return False
    wanted = str(condition).lower().strip()
    cell_str = str(cell)
    for separator in ('||', ','):
        if separator in cell_str:
            return any(
                token.lower().strip() == wanted
                for token in cell_str.split(separator)
            )
    return cell_str.lower().strip() == wanted

def validate_join_relationships(tables):
    """Check that every pair of the given tables has a known JOIN relationship.

    Args:
        tables: iterable of table names. A list, a tuple, or a dict keyed by
            table name (such as the loaded-tables dict) all work; only the
            names are used.

    Returns:
        ``(ok, message)``: ``ok`` is False as soon as one pair has no entry in
        the table of known relationships, and ``message`` names that pair.
    """
    # Keys are alphabetically sorted table pairs; the values only describe the
    # join condition and are not evaluated here.
    valid_pairs = {
        ("disease", "drug"): "disease_name",
        ("drug", "institutes"): "manufacturer -> institution_name", 
        ("disease", "institutes"): "disease_name in research_diseases"
    }

    # Materialise the names so dicts and one-shot iterables are accepted too.
    names = list(tables)

    if len(names) <= 1:
        return True, "Single table, no JOIN needed"

    for first, second in itertools.combinations(names, 2):
        table1, table2 = sorted([first, second])
        if (table1, table2) not in valid_pairs:
            return False, f"No valid JOIN relationship between {table1} and {table2}"

    return True, "All JOIN relationships are valid"


#### 汇总可取属性值

In [None]:
def _cell_tokens(cell):
    """Split one multi-value cell into stripped, non-empty tokens.

    "||" takes precedence over ","; a cell with neither separator is returned
    as a single stripped token.
    """
    text = str(cell)
    for separator in ('||', ','):
        if separator in text:
            return [token.strip() for token in text.split(separator) if token.strip()]
    return [text.strip()]


def get_sample_values(tables):
    """Collect candidate filter values for every attribute of every table.

    * numeric attributes: the 10/25/50/75/90 % quantiles plus the first 10
      distinct values, de-duplicated and capped at 15;
    * multi-value attributes: the first 20 distinct split values;
    * all other attributes: the first 20 distinct non-null values.

    De-duplication keeps first-seen order. It previously went through
    ``set()``, whose iteration order for strings changes between interpreter
    runs, so a different subset of values could be chosen on every run.

    Args:
        tables: dict mapping table name -> DataFrame.

    Returns:
        ``{table: {attribute: [values]}}``. Attributes missing from the
        DataFrame, and numeric attributes without any numeric value, get no
        entry.
    """
    attr_value_dict = {}

    for table_name, df in tqdm(tables.items(), desc="获取属性值"):
        table_values = attr_value_dict[table_name] = {}

        for attr in tqdm(table_attributes[table_name].keys(), desc=f"处理{table_name}属性值", leave=False):
            if attr not in df.columns:
                continue

            non_null_values = df[attr].dropna()

            if attr in numerical_attr_list:
                try:
                    numeric_values = pd.to_numeric(non_null_values, errors='coerce').dropna()
                    if len(numeric_values) > 0:
                        candidates = [numeric_values.quantile(q) for q in (0.1, 0.25, 0.5, 0.75, 0.9)]
                        candidates.extend(numeric_values.unique()[:10])
                        table_values[attr] = list(dict.fromkeys(candidates))[:15]
                except (TypeError, ValueError):
                    # Column cannot be interpreted as numbers: no candidates.
                    table_values[attr] = []
            elif attr in multi_value_attributes_list:
                # Multi-value attribute: only individual split values are used.
                tokens = [
                    token
                    for cell in non_null_values
                    for token in _cell_tokens(cell)
                ]
                table_values[attr] = list(dict.fromkeys(tokens))[:20]
            else:
                table_values[attr] = list(non_null_values.unique()[:20])

    return attr_value_dict

# Candidate filter values per table and attribute, computed once for all tables.
attr_value_dict = get_sample_values(tables)

#### 枚举所有可能的Filter

In [None]:
def build_filter_dict(tables, attr_value_dict):
    """Enumerate single-attribute filters and the rows each one selects.

    For text attributes, up to 10 candidate values are turned into
    ``=='value'`` filters. For numeric attributes, up to 8 candidate values
    are combined with ``>``, ``<``, ``==``, ``>=`` and ``<=``. A filter is
    kept only when it selects at least ``min_rows`` rows.

    Args:
        tables: dict mapping table name -> DataFrame.
        attr_value_dict: ``{table: {attribute: [candidate values]}}``.

    Returns:
        ``{table: {attribute: {condition string: [row index labels]}}}``.

    NOTE(review): candidate values are embedded in the condition string
    without escaping, so a value containing a single quote yields
    ``=='it's'`` — confirm the consumer of these strings tolerates that.
    """
    filter_dict = {}

    for table_name, df in tqdm(tables.items(), desc="处理表"):
        filter_dict[table_name] = {}

        for attr in tqdm(table_attributes[table_name].keys(), desc=f"处理{table_name}属性", leave=False):
            if attr not in df.columns or attr not in attr_value_dict[table_name]:
                continue

            condition_dict = {}

            if attr in non_numerical_attr[table_name]:
                # Use single values only: drop empties and anything still
                # containing the "||" separator.
                candidates = [str(value).strip() for value in attr_value_dict[table_name][attr][:10]]
                candidates = [text for text in candidates if text and '||' not in text]

                # The matcher depends on the attribute only, so pick it once.
                if attr in multi_value_attributes.get(table_name, []):
                    matcher = non_numerical_equal_to_with_split
                else:
                    matcher = non_numerical_equal_to

                for possible_value in candidates:
                    result = df[attr].apply(matcher, condition=possible_value)
                    result_indices = df[result].index.tolist()
                    if len(result_indices) >= min_rows:
                        condition_dict[f"=='{possible_value}'"] = result_indices

            elif attr in numerical_attr[table_name]:
                for possible_value in attr_value_dict[table_name][attr][:8]:
                    try:
                        possible_value = float(possible_value)

                        operations = [
                            (number_greater_than, f">{possible_value}"),
                            (number_less_than, f"<{possible_value}"),
                            (number_equal_to, f"=={possible_value}"),
                            (number_greater_equal, f">={possible_value}"),
                            (number_less_equal, f"<={possible_value}")
                        ]

                        for operation_func, operation_str in operations:
                            result = df[attr].apply(operation_func, condition=possible_value)
                            result_indices = df[result].index.tolist()
                            if len(result_indices) >= min_rows:
                                condition_dict[operation_str] = result_indices
                    except (TypeError, ValueError):
                        # Candidate is not usable as a number: skip it.
                        continue

            filter_dict[table_name][attr] = condition_dict

    return filter_dict

# Build every candidate filter, then keep a JSON copy in the working directory.
filter_dict = build_filter_dict(tables, attr_value_dict)

# ensure_ascii=False keeps non-ASCII attribute values readable in the file.
with open("./filter_dict_multi_table.json", 'w', encoding='utf-8') as f:
    json.dump(filter_dict, f, ensure_ascii=False)

#### 均匀采样函数

In [None]:
def balanced_sample(filters, sample_size=80, random_seed=None):
    """Sample ``sample_size`` filters, spread as evenly as possible over tables.

    Filters are grouped by their first element (the table name). Each table
    gets an equal quota, and the first ``sample_size % n_tables`` tables get
    one extra. When a table has fewer filters than its quota, the missing
    slots are filled with a random draw from the filters not yet chosen, so
    the result has exactly ``sample_size`` items. Previously the result was
    silently shorter in that case.

    Args:
        filters: sequence of filter tuples; ``item[0]`` is the table name.
        sample_size: number of filters to return.
        random_seed: when not None, seeds the module-level ``random``
            generator. This also makes later ``random`` calls in the same
            process deterministic.

    Returns:
        ``filters`` itself when it has at most ``sample_size`` items,
        otherwise a new list of ``sample_size`` filters.
    """
    if random_seed is not None:
        random.seed(random_seed)

    if len(filters) <= sample_size:
        return filters

    table_to_filters = defaultdict(list)
    for filter_item in filters:
        table_to_filters[filter_item[0]].append(filter_item)

    per_table, remainder = divmod(sample_size, len(table_to_filters))

    sampled_filters = []
    leftovers = []
    for i, table_filters in enumerate(table_to_filters.values()):
        quota = min(per_table + (1 if i < remainder else 0), len(table_filters))

        # Sample positions rather than items: filter tuples may contain sets
        # and are therefore not hashable, but positions are.
        picked_positions = random.sample(range(len(table_filters)), quota)
        sampled_filters.extend(table_filters[pos] for pos in picked_positions)

        picked = set(picked_positions)
        leftovers.extend(
            item for pos, item in enumerate(table_filters) if pos not in picked
        )

    # len(filters) > sample_size guarantees enough leftovers to top up.
    shortfall = sample_size - len(sampled_filters)
    if shortfall > 0:
        sampled_filters.extend(random.sample(leftovers, shortfall))

    return sampled_filters

#### 构建WHERE条件组合

In [None]:
def _count_combinations(total, size):
    """Return C(total, size) using integer arithmetic only."""
    if size < 0 or size > total:
        return 0
    count = 1
    for step in range(1, size + 1):
        count = count * (total - size + step) // step
    return count


def _draw_combinations(items, size, budget):
    """Return up to ``budget`` distinct ``size``-combinations of ``items``, in random order.

    Each combination keeps the relative order of ``items``, as
    ``itertools.combinations`` would.
    """
    if budget <= 0:
        return []

    if _count_combinations(len(items), size) <= 4 * budget:
        # Small search space: enumerate everything, then shuffle.
        combos = list(itertools.combinations(items, size))
        random.shuffle(combos)
        return combos[:budget]

    # Large search space: draw index tuples directly and reject repeats. At
    # most a quarter of the space is ever requested, so collisions are rare.
    seen = set()
    drawn = []
    while len(drawn) < budget:
        picked = tuple(sorted(random.sample(range(len(items)), size)))
        if picked in seen:
            continue
        seen.add(picked)
        drawn.append(tuple(items[i] for i in picked))
    return drawn


def build_where_combinations(filter_dict, max_combinations=1000):
    """Combine single filters into WHERE clauses of 1 to 5 predicates.

    Up to 80 filters are sampled evenly across tables (seed 42). For each
    predicate count ``n``, a share of ``max_combinations`` is examined in the
    ratio 2:3:3:1:1 for n = 1..5. Every combination is tried with AND and OR.
    A clause is kept when it selects at least ``min_rows`` rows and its text
    has not been produced before.

    Only the combinations that will be examined are drawn. The full list of
    combinations (about 24 million tuples for 80 filters and n = 5) is no
    longer built and shuffled. The selection is still uniformly random and
    reproducible, but differs from what earlier versions produced for the
    same seed.

    Row sets: filters on one table are intersected (AND) or united (OR). For
    filters on several tables, the row set of the first filter is used as a
    stand-in, because no JOIN is evaluated here.

    NOTE(review): clause text carries no table prefix, so identical predicates
    on an attribute shared by two tables (e.g. ``disease_name``) are
    de-duplicated against each other.

    Args:
        filter_dict: ``{table: {attribute: {condition: [row indices]}}}``.
        max_combinations: total budget of combinations to examine.

    Returns:
        List of dicts with the keys "WHERE Indices", "WHERE Total Rows",
        "WHERE", "Tables", "Combination", "Operator" and "Filter Count".
    """
    all_filters = [
        (table_name, attr, cond, set(indices))
        for table_name, table_filters in filter_dict.items()
        for attr, conditions in table_filters.items()
        for cond, indices in conditions.items()
        if len(indices) >= min_rows
    ]

    # Sample enough filters to have material for the combinations.
    sampled_filters = balanced_sample(all_filters, sample_size=80, random_seed=42)

    # Tenths of max_combinations granted to each predicate count.
    budget_tenths = {1: 2, 2: 3, 3: 3, 4: 1, 5: 1}

    valid_where = []
    seen_expressions = set()

    for n in tqdm(range(1, min(6, len(sampled_filters) + 1)), desc="生成Filter组合"):
        max_combos_this_round = int(max_combinations * budget_tenths.get(n, 0) / 10)
        combinations = _draw_combinations(sampled_filters, n, max_combos_this_round)

        # A single predicate has no operator, so one pass is enough.
        operators = ['AND', 'OR'] if n > 1 else ['NONE']

        for combo in tqdm(combinations, desc=f"{n}个Filter组合", leave=False):
            combo_tables = [item[0] for item in combo]
            # dict.fromkeys de-duplicates while keeping a reproducible order.
            tables_involved = list(dict.fromkeys(combo_tables))
            predicates = [f"{item[1]}{item[2]}" for item in combo]

            for op in operators:
                if n == 1:
                    where_clause = predicates[0]
                    result_indices = combo[0][3]
                else:
                    where_clause = f" {op} ".join(predicates)
                    if len(tables_involved) == 1:
                        row_sets = [item[3] for item in combo]
                        if op == 'AND':
                            result_indices = set.intersection(*row_sets)
                        else:
                            result_indices = set.union(*row_sets)
                    else:
                        # Multi-table simplification: no JOIN is evaluated,
                        # the first predicate's rows stand in for the result.
                        result_indices = combo[0][3]

                if len(result_indices) < min_rows or where_clause in seen_expressions:
                    continue
                seen_expressions.add(where_clause)

                valid_where.append({
                    "WHERE Indices": list(result_indices),
                    "WHERE Total Rows": len(result_indices),
                    "WHERE": where_clause,
                    "Tables": list(tables_involved),
                    "Combination": [[item[0], item[1], item[2]] for item in combo],
                    "Operator": op,
                    "Filter Count": n
                })

    return valid_where

# Build the valid WHERE clauses and persist them to the JSON file configured
# in valid_where_output_dir.
valid_where = build_where_combinations(filter_dict)  

# ensure_ascii=False keeps non-ASCII predicate values readable in the file.
with open(valid_where_output_dir, 'w', encoding='utf-8') as f:     
    json.dump(valid_where, f, ensure_ascii=False)

#### 定义SCHEMA创建函数

In [None]:
def create_schema(attr_desc_dict, query_attr_list):
    """Map each queried attribute to ``[sql_type, description]``.

    An attribute is typed "FLOAT" only when it is listed as numerical and not
    listed as non-numerical; everything else is "VARCHAR(255)". A missing
    description defaults to the empty string.
    """
    def _sql_type(name):
        is_float = name not in non_numerical_attr_list and name in numerical_attr_list
        return "FLOAT" if is_float else "VARCHAR(255)"

    return {
        name: [_sql_type(name), attr_desc_dict.get(name, "")]
        for name in query_attr_list
    }

def create_multi_table_schema(tables, selected_attrs):
    """Build a schema keyed by ``"<table>.<attribute>"`` for a multi-table query.

    ``selected_attrs`` is an iterable of ``(table, attribute)`` pairs. Each
    entry maps to ``[sql_type, description]``, where the type is "FLOAT" for
    attributes listed as numerical and "VARCHAR(255)" otherwise. The
    ``tables`` argument is accepted but not used.
    """
    return {
        f"{table}.{attr}": [
            "FLOAT" if attr in numerical_attr_list else "VARCHAR(255)",
            f"{table} table {attr} field",
        ]
        for table, attr in selected_attrs
    }

# Generic WHERE-condition picker
def select_where_by_ratio(valid_where, total_needed, single_table_only=True, multi_table_ratio=0.3):
    """Pick WHERE conditions with a fixed share of multi-filter clauses.

    Args:
        valid_where: list of WHERE dicts carrying "Tables" and, optionally,
            "Filter Count" (a missing count is read as 1).
        total_needed: number of conditions to return.
        single_table_only: when True, only conditions touching exactly one
            table are eligible.
        multi_table_ratio: target share of conditions with more than one
            filter. Despite the name it is about the filter count, not the
            table count. The default 0.3 means 30 % multi-filter and 70 %
            single-filter.

    Returns:
        Up to ``total_needed`` conditions, multi-filter ones first. When one
        pool is too small for its share, the gap is filled from the other
        pool; the result is shorter only when both pools are exhausted.
    """
    single_filter = []
    multi_filter = []
    for where in valid_where:
        if single_table_only and len(where["Tables"]) != 1:
            continue
        filter_count = where.get("Filter Count", 1)
        if filter_count == 1:
            single_filter.append(where)
        elif filter_count > 1:
            multi_filter.append(where)

    # Round away binary floating-point noise before truncating: 25 * (1 - 0.8)
    # evaluates to 4.999999999999999 and would otherwise become 4, not 5.
    single_count = int(round(total_needed * (1 - multi_table_ratio), 9))
    multi_count = total_needed - single_count

    selected = multi_filter[:multi_count] + single_filter[:single_count]

    shortfall = total_needed - len(selected)
    if shortfall > 0:
        # Backfill from whatever the two pools still have to offer.
        spare = multi_filter[multi_count:] + single_filter[single_count:]
        selected.extend(spare[:shortfall])

    return selected[:total_needed]

#### 构建SELECT FROM（单表）

In [None]:
def generate_select_from_queries(tables):
    """Generate plain single-table ``SELECT ... FROM`` queries (no WHERE).

    ``sample_sf`` queries are split as evenly as possible over the tables,
    with the first tables receiving the remainder. Each query selects a random
    number (1 to ``max_select``) of randomly chosen attributes and records
    every row position of the table as its result set.
    """
    table_names = list(tables.keys())
    base_quota, extra = divmod(sample_sf, len(table_names))

    queries = []
    for position, table_name in enumerate(tqdm(table_names, desc="生成SF查询")):
        quota = base_quota + 1 if position < extra else base_quota

        for _ in range(quota):
            candidate_attrs = list(table_attributes[table_name].keys())

            # Random attribute count between 1 and max_select.
            how_many = random.randint(1, min(max_select, len(candidate_attrs)))
            chosen_attrs = random.sample(candidate_attrs, how_many)

            row_total = len(tables[table_name])

            queries.append({
                "Type": "SF",
                "SELECT": chosen_attrs,
                "FROM": [table_name],
                "WHERE Indices": list(range(row_total)),  # every row position
                "WHERE Total Rows": row_total,
                "WHERE": "",  # no WHERE condition
                "Tables": [table_name],
                "Combination": [],  # no filter combination
                "Operator": "NONE",
                "Filter Count": 0,
                "SCHEMA": create_schema(attr_desc_dict, chosen_attrs)
            })

    return queries

# Plain SELECT ... FROM queries, one batch per loaded table.
sf_queries = generate_select_from_queries(tables)

#### 构建SELECT FROM JOIN查询（多表）

In [None]:
def generate_select_from_join_queries(tables):
    """Generate multi-table ``SELECT ... FROM ... JOIN`` queries (no WHERE).

    ``sample_sfj`` queries are split as evenly as possible over the four
    supported table combinations. Each query selects 2 to ``max_select``
    attributes drawn from the first five attributes of every joined table.
    The JOIN itself is not executed: the row count is an estimate (the size of
    the smallest participating table) and "WHERE Indices" is left empty.
    """
    # Supported JOIN combinations
    join_combinations = [
        ['disease', 'drug'],
        ['drug', 'institutes'],
        ['disease', 'institutes'],
        ['disease', 'drug', 'institutes']
    ]

    base_quota, extra = divmod(sample_sfj, len(join_combinations))

    queries = []
    for position, join_combo in enumerate(tqdm(join_combinations, desc="生成SFJ查询")):
        quota = base_quota + 1 if position < extra else base_quota

        for _ in range(quota):
            # Only the first 5 attributes of each table are candidates, which
            # keeps the SELECT clause short.
            candidates = [
                (table, attr)
                for table in join_combo
                for attr in list(table_attributes[table].keys())[:5]
            ]

            # Random attribute count between 2 and max_select.
            pick_count = random.randint(2, min(max_select, len(candidates)))
            selected_attrs = random.sample(candidates, pick_count)
            select_clause = [f"{table}.{attr}" for table, attr in selected_attrs]

            # Estimate of the JOIN size: the smallest participating table.
            estimated_rows = min(len(tables[table]) for table in join_combo)

            queries.append({
                "Type": "SFJ",
                "SELECT": select_clause,
                "FROM": join_combo,
                "JOIN_TYPE": "INNER JOIN",
                "WHERE Indices": [],  # unknown until the JOIN is executed
                "WHERE Total Rows": estimated_rows,  # estimated row count
                "WHERE": "",  # no WHERE condition
                "Tables": join_combo,
                "Combination": [],  # no filter combination
                "Operator": "NONE",
                "Filter Count": 0,
                "SCHEMA": create_multi_table_schema(join_combo, selected_attrs)
            })

    return queries

# SELECT ... FROM ... JOIN queries without a WHERE clause.
sfj_queries = generate_select_from_join_queries(tables)

#### 构建SELECT FROM WHERE查询（单表）

In [None]:
def generate_select_from_where_queries(valid_where, tables):
    """Generate single-table ``SELECT ... FROM ... WHERE`` queries.

    ``sample_sfw`` single-table WHERE conditions are picked with a 0.7
    multi-filter share. Each becomes a query that selects up to ``max_select``
    random attributes out of the table's first eight. The schema covers the
    selected attributes plus every attribute used in the WHERE clause. The
    ``tables`` argument is accepted but not used.
    """
    chosen_where = select_where_by_ratio(valid_where, sample_sfw, single_table_only=True, multi_table_ratio=0.7)

    queries = []
    for where_dict in tqdm(chosen_where, desc="生成SFW查询"):
        table_name = where_dict["Tables"][0]

        candidate_attrs = list(table_attributes[table_name].keys())[:8]
        selected_attrs = random.sample(candidate_attrs, min(max_select, len(candidate_attrs)))

        # Schema attributes: the selected ones first, then any attribute that
        # is referenced by a predicate only.
        schema_attrs = list(selected_attrs)
        for predicate in where_dict["Combination"]:
            predicate_attr = predicate[1]
            if predicate_attr not in schema_attrs:
                schema_attrs.append(predicate_attr)

        queries.append({
            **where_dict,
            "Type": "SFW",
            "SELECT": selected_attrs,
            "FROM": [table_name],
            "SCHEMA": create_schema(attr_desc_dict, schema_attrs),
        })

    return queries

# Single-table SELECT ... FROM ... WHERE queries built from the valid WHERE clauses.
sfw_queries = generate_select_from_where_queries(valid_where, tables)

#### 构建多表JOIN查询

In [None]:
def generate_join_queries(valid_where, tables):
    """Build SELECT ... FROM ... JOIN ... WHERE (SFWJ) query descriptions.

    Two batches of WHERE conditions are drawn with ``select_where_by_ratio``:

    * a mixed batch, where a condition spanning several tables is joined over
      exactly those tables;
    * a single-table batch (first 10 entries only), where each condition is
      applied to the predefined two-table join combinations.

    A WHERE condition is only ever attached to a FROM clause that contains
    every table the condition refers to. Previously a predicate on e.g.
    ``institutes`` could be emitted on top of ``disease JOIN drug``, which
    references a table missing from FROM.

    Args:
        valid_where: Candidate WHERE-condition dicts; each must provide the
            ``"Tables"`` key.
        tables: Not used by this generator; kept so that all generators share
            the same call signature.

    Returns:
        A list of query dicts. Each is a shallow copy of the WHERE dict
        extended with ``"Type"``, ``"SELECT"``, ``"FROM"``, ``"JOIN_TYPE"``
        and ``"SCHEMA"``.
    """
    queries = []
    join_combinations = [['disease', 'drug'], ['drug', 'institutes'], ['disease', 'drug', 'institutes']]

    def _build_query(where_dict, from_tables):
        """Combine one WHERE dict with a random projection over ``from_tables``."""
        # Only the first three attributes of each table are eligible for SELECT.
        available_attrs = [
            (table, attr)
            for table in from_tables
            for attr in list(table_attributes[table].keys())[:3]
        ]
        selected_attrs = random.sample(available_attrs, min(max_select, len(available_attrs)))

        query_dict = where_dict.copy()
        query_dict.update({
            "Type": "SFWJ",
            "SELECT": [f"{table}.{attr}" for table, attr in selected_attrs],
            "FROM": from_tables,
            "JOIN_TYPE": "INNER JOIN",
            "SCHEMA": create_multi_table_schema(from_tables, selected_attrs)
        })
        return query_dict

    multi_table_where = select_where_by_ratio(valid_where, sample_sfwj//2, single_table_only=False, multi_table_ratio=0.6)
    single_table_for_join = select_where_by_ratio(valid_where, sample_sfwj//2, single_table_only=True, multi_table_ratio=0.8)

    # JOIN queries driven by (potentially) multi-table WHERE conditions.
    for where_dict in tqdm(multi_table_where, desc="生成多表JOIN查询"):
        where_tables = where_dict["Tables"]
        if len(where_tables) > 1:
            tables_involved = where_tables
        else:
            # A condition on zero or one table still needs a join partner.
            # Pick the first predefined combination that contains its table;
            # this stays ['disease', 'drug'] for disease/drug conditions and
            # becomes ['drug', 'institutes'] for institutes conditions.
            tables_involved = list(next(
                (combo for combo in join_combinations if set(where_tables).issubset(combo)),
                join_combinations[0],
            ))
        queries.append(_build_query(where_dict, tables_involved))

    # Single-table WHERE conditions applied on top of two-table joins.
    for where_dict in tqdm(single_table_for_join[:10], desc="生成单表->JOIN查询", leave=False):
        where_tables = set(where_dict["Tables"])
        for join_combo in join_combinations[:2]:
            # Skip joins that do not include the table the predicate filters.
            if not where_tables.issubset(join_combo):
                continue
            queries.append(_build_query(where_dict, join_combo))

    return queries


join_queries = generate_join_queries(valid_where, tables)

#### 构建SELECT FROM WHERE TOP-K查询

In [None]:
def generate_topk_queries(valid_where, tables):
    """Build TOP-K (ORDER BY ... LIMIT) query descriptions.

    Produces single-table ``SFWT`` queries and joined ``SFWTJ`` queries. A
    query is only emitted when a numerical attribute (one listed in
    ``numerical_attr_list``) is available to order by.

    Changes from the earlier version:

    * the ORDER BY column is no longer added to the schema attribute list a
      second time when it is already among the selected columns;
    * a WHERE condition is only applied to a join combination that contains
      every table named in the condition's ``"Tables"`` entry, so the
      predicate never references a table missing from FROM.

    Args:
        valid_where: Candidate WHERE-condition dicts (``"Tables"`` and
            ``"Combination"`` keys are read).
        tables: Not used by this generator; kept so that all generators share
            the same call signature.

    Returns:
        A list of query dicts: shallow copies of the WHERE dicts extended with
        SELECT / FROM / ORDER BY / LIMIT / SCHEMA information.
    """
    queries = []
    order_options = ['ASC', 'DESC']

    single_table_where = select_where_by_ratio(valid_where, sample_sfwt, single_table_only=True, multi_table_ratio=0.6)

    for where_dict in single_table_where:
        table_name = where_dict["Tables"][0]

        available_attrs = list(table_attributes[table_name].keys())
        numerical_attrs = [attr for attr in available_attrs if attr in numerical_attr_list]
        if not numerical_attrs:
            # Nothing to order by for this table.
            continue

        # Only the first eight attributes are eligible for SELECT.
        candidate_attrs = available_attrs[:8]
        selected_attrs = random.sample(candidate_attrs, min(max_select, len(candidate_attrs)))
        order_column = random.choice(numerical_attrs)
        order_type = random.choice(order_options)
        limit_value = random.choice(limit_list)

        # Schema = selected columns + ORDER BY column + WHERE columns,
        # each listed once.
        query_attr_list = list(selected_attrs)
        for attr in [order_column] + [item[1] for item in where_dict["Combination"]]:
            if attr not in query_attr_list:
                query_attr_list.append(attr)

        query_dict = where_dict.copy()
        query_dict.update({
            "Type": "SFWT",
            "SELECT": selected_attrs,
            "FROM": [table_name],
            "ORDER BY": [order_column, order_type],
            "LIMIT": limit_value,
            "SCHEMA": create_schema(attr_desc_dict, query_attr_list)
        })
        queries.append(query_dict)

    # JOIN TOP-K queries
    multi_table_where = select_where_by_ratio(valid_where, sample_sfwtj//2, single_table_only=False, multi_table_ratio=0.7)
    single_table_for_join = select_where_by_ratio(valid_where, sample_sfwtj//2, single_table_only=True, multi_table_ratio=0.8)

    join_combinations = [['disease', 'drug'], ['drug', 'institutes'], ['disease', 'drug', 'institutes']]

    for where_dict in multi_table_where + single_table_for_join[:10]:
        where_tables = set(where_dict.get("Tables", []))

        # Only the two-table combinations are used here.
        for join_combo in join_combinations[:2]:
            # Skip joins lacking a table that the WHERE condition filters on.
            if not where_tables.issubset(join_combo):
                continue

            # First three attributes of each table are eligible for SELECT.
            available_attrs = [
                (table, attr)
                for table in join_combo
                for attr in list(table_attributes[table].keys())[:3]
            ]
            numerical_join_attrs = [
                (table, attr)
                for table in join_combo
                for attr in table_attributes[table].keys()
                if attr in numerical_attr_list
            ]
            if not numerical_join_attrs:
                continue

            selected_attrs = random.sample(available_attrs, min(max_select, len(available_attrs)))
            select_clause = [f"{table}.{attr}" for table, attr in selected_attrs]

            order_table, order_attr = random.choice(numerical_join_attrs)
            order_column = f"{order_table}.{order_attr}"
            order_type = random.choice(order_options)
            limit_value = random.choice(limit_list)

            query_dict = where_dict.copy()
            query_dict.update({
                "Type": "SFWTJ",
                "SELECT": select_clause,
                "FROM": join_combo,
                "JOIN_TYPE": "INNER JOIN",
                "ORDER BY": [order_column, order_type],
                "LIMIT": limit_value,
                "SCHEMA": create_multi_table_schema(join_combo, selected_attrs)
            })
            queries.append(query_dict)

    return queries


topk_queries = generate_topk_queries(valid_where, tables)

#### 构建SELECT FROM WHERE GROUP BY查询

In [None]:
def generate_groupby_queries(valid_where, tables):
    """Build GROUP BY query descriptions (``SFWG`` and joined ``SFWGJ``).

    The grouping column is drawn from ``category_attr`` for the table(s)
    involved. Tables or join combinations without a categorical attribute
    yield no query.

    Changes from the earlier version:

    * in the join variant the GROUP BY column is appended to SELECT only when
      it was not already sampled, so SELECT no longer lists it twice;
    * a WHERE condition is only applied to a join combination that contains
      every table named in the condition's ``"Tables"`` entry.

    Args:
        valid_where: Candidate WHERE-condition dicts (``"Tables"`` and
            ``"Combination"`` keys are read).
        tables: Not used by this generator; kept so that all generators share
            the same call signature.

    Returns:
        A list of query dicts: shallow copies of the WHERE dicts extended with
        SELECT / FROM / GROUP BY / SCHEMA information.
    """
    queries = []

    single_table_where = select_where_by_ratio(valid_where, sample_sfwg, single_table_only=True, multi_table_ratio=0.7)

    for where_dict in single_table_where:
        table_name = where_dict["Tables"][0]

        category_attrs = category_attr.get(table_name, [])
        if not category_attrs:
            continue

        groupby_columns = random.sample(category_attrs, 1)
        # Only the first eight attributes are eligible for SELECT.
        candidate_attrs = list(table_attributes[table_name].keys())[:8]
        selected_attrs = random.sample(candidate_attrs, min(max_select, len(candidate_attrs)))

        # The grouping column must be part of SELECT; make room for it while
        # keeping at most max_select columns.
        if groupby_columns[0] not in selected_attrs:
            selected_attrs = groupby_columns + selected_attrs[:max_select-1]

        query_attr_list = selected_attrs.copy()
        for item in where_dict["Combination"]:
            if item[1] not in query_attr_list:
                query_attr_list.append(item[1])

        query_dict = where_dict.copy()
        query_dict.update({
            "Type": "SFWG",
            "SELECT": selected_attrs,
            "FROM": [table_name],
            "GROUP BY": groupby_columns,
            "SCHEMA": create_schema(attr_desc_dict, query_attr_list)
        })
        queries.append(query_dict)

    # JOIN GROUP BY queries
    multi_table_where = select_where_by_ratio(valid_where, sample_sfwgj//2, single_table_only=False, multi_table_ratio=0.6)
    single_table_for_join = select_where_by_ratio(valid_where, sample_sfwgj//2, single_table_only=True, multi_table_ratio=0.8)

    join_combinations = [['disease', 'drug'], ['drug', 'institutes']]

    for where_dict in multi_table_where + single_table_for_join[:10]:
        where_tables = set(where_dict.get("Tables", []))

        for join_combo in join_combinations:
            # Skip joins lacking a table that the WHERE condition filters on.
            if not where_tables.issubset(join_combo):
                continue

            join_category_attrs = [
                (table, attr)
                for table in join_combo
                for attr in category_attr.get(table, [])
            ]
            if not join_category_attrs:
                continue

            # First three attributes of each table are eligible for SELECT.
            available_attrs = [
                (table, attr)
                for table in join_combo
                for attr in list(table_attributes[table].keys())[:3]
            ]

            groupby_key = random.choice(join_category_attrs)
            groupby_column = f"{groupby_key[0]}.{groupby_key[1]}"

            selected_attrs = random.sample(available_attrs, min(max_select-1, len(available_attrs)))
            # Avoid listing the grouping column twice in SELECT / SCHEMA.
            if groupby_key not in selected_attrs:
                selected_attrs.append(groupby_key)
            select_clause = [f"{table}.{attr}" for table, attr in selected_attrs]

            query_dict = where_dict.copy()
            query_dict.update({
                "Type": "SFWGJ",
                "SELECT": select_clause,
                "FROM": join_combo,
                "JOIN_TYPE": "INNER JOIN",
                "GROUP BY": [groupby_column],
                "SCHEMA": create_multi_table_schema(join_combo, selected_attrs)
            })
            queries.append(query_dict)

    return queries


groupby_queries = generate_groupby_queries(valid_where, tables)


#### 构建SELECT FROM WHERE AGGREGATION查询

In [None]:
def generate_aggregation_queries(valid_where, tables):
    """Build aggregation query descriptions (``SFWA`` and joined ``SFWAJ``).

    When the table(s) involved expose a numerical attribute (per
    ``numerical_attr``), one is aggregated with a random function out of
    MAX/MIN/AVG/SUM; otherwise the query falls back to ``COUNT(*)``.

    Change from the earlier version: a WHERE condition is only applied to a
    join combination that contains every table named in the condition's
    ``"Tables"`` entry, so the predicate never references a table missing
    from FROM.

    Args:
        valid_where: Candidate WHERE-condition dicts (``"Tables"`` and
            ``"Combination"`` keys are read).
        tables: Not used by this generator; kept so that all generators share
            the same call signature.

    Returns:
        A list of query dicts: shallow copies of the WHERE dicts extended with
        SELECT / FROM / AGGREGATION / SCHEMA information.
    """
    queries = []
    aggregation_functions = ['COUNT', 'MAX', 'MIN', 'AVG', 'SUM']
    # COUNT is reserved for the "*" fallback; numerical columns use the rest.
    numerical_functions = aggregation_functions[1:]

    single_table_where = select_where_by_ratio(valid_where, sample_sfwa, single_table_only=True, multi_table_ratio=0.6)

    for where_dict in single_table_where:
        table_name = where_dict["Tables"][0]
        numerical_attrs = numerical_attr.get(table_name, [])

        if numerical_attrs:
            agg_column = random.choice(numerical_attrs)
            function = random.choice(numerical_functions)
            numerical = True
        else:
            agg_column = "*"
            function = 'COUNT'
            numerical = False

        # Schema = aggregated column (if any) + WHERE columns, each once.
        query_attr_list = [agg_column] if agg_column != "*" else []
        for item in where_dict["Combination"]:
            if item[1] not in query_attr_list:
                query_attr_list.append(item[1])

        query_dict = where_dict.copy()
        query_dict.update({
            "Type": "SFWA",
            "SELECT": [f"{function}({agg_column})"],
            "FROM": [table_name],
            "AGGREGATION": [agg_column],
            "AGGREGATION Function": function,
            "Numerical": numerical,
            "SCHEMA": create_schema(attr_desc_dict, query_attr_list)
        })
        queries.append(query_dict)

    # JOIN AGGREGATION queries
    multi_table_where = select_where_by_ratio(valid_where, sample_sfwaj//2, single_table_only=False, multi_table_ratio=0.7)
    single_table_for_join = select_where_by_ratio(valid_where, sample_sfwaj//2, single_table_only=True, multi_table_ratio=0.8)

    join_combinations = [['disease', 'drug'], ['drug', 'institutes']]

    for where_dict in multi_table_where + single_table_for_join[:10]:
        where_tables = set(where_dict.get("Tables", []))

        for join_combo in join_combinations:
            # Skip joins lacking a table that the WHERE condition filters on.
            if not where_tables.issubset(join_combo):
                continue

            join_numerical_attrs = [
                (table, attr)
                for table in join_combo
                for attr in numerical_attr.get(table, [])
            ]

            if join_numerical_attrs:
                agg_key = random.choice(join_numerical_attrs)
                agg_column = f"{agg_key[0]}.{agg_key[1]}"
                function = random.choice(numerical_functions)
                numerical = True
                schema_attrs = [agg_key]
            else:
                agg_column = "*"
                function = 'COUNT'
                numerical = False
                schema_attrs = []

            query_dict = where_dict.copy()
            query_dict.update({
                "Type": "SFWAJ",
                "SELECT": [f"{function}({agg_column})"],
                "FROM": join_combo,
                "JOIN_TYPE": "INNER JOIN",
                "AGGREGATION": [agg_column],
                "AGGREGATION Function": function,
                "Numerical": numerical,
                "SCHEMA": create_multi_table_schema(join_combo, schema_attrs)
            })
            queries.append(query_dict)

    return queries


aggregation_queries = generate_aggregation_queries(valid_where, tables)

#### 构建SELECT FROM AGGREGATION GROUP BY查询

In [None]:
def generate_aggregation_groupby_queries(tables):
    """Build aggregation + GROUP BY query descriptions without a WHERE clause.

    * ``SFAG``: single-table queries; ``sample_sfag`` is split as evenly as
      possible across the tables in ``tables``.
    * ``SFAGJ``: multi-table JOIN queries; ``sample_sfagj`` is split as evenly
      as possible across the predefined join combinations.

    Tables / combinations without a categorical attribute (per
    ``category_attr``) yield no query, so fewer than the requested number of
    queries may be returned.

    Robustness changes from the earlier version: an empty ``tables`` mapping
    no longer raises ``ZeroDivisionError``, and a join combination none of
    whose tables is loaded no longer raises ``ValueError`` from ``min()``
    (its row estimate falls back to 1).

    Args:
        tables: Mapping of table name to the loaded table (anything that
            supports ``len``).

    Returns:
        A list of query dicts with empty WHERE information.
    """
    queries = []
    aggregation_functions = ['COUNT', 'MAX', 'MIN', 'AVG', 'SUM']
    # COUNT is reserved for the "*" fallback; numerical columns use the rest.
    numerical_functions = aggregation_functions[1:]

    # 1. Single-table SFAG queries
    table_names = list(tables.keys())
    if table_names:
        queries_per_table, remainder = divmod(sample_sfag, len(table_names))
    else:
        queries_per_table = remainder = 0

    for i, table_name in enumerate(tqdm(table_names, desc="生成SFAG查询")):
        # The first `remainder` tables each take one extra query.
        table_query_count = queries_per_table + (1 if i < remainder else 0)

        category_attrs = category_attr.get(table_name, [])
        numerical_attrs = numerical_attr.get(table_name, [])
        if not category_attrs:
            # Nothing to group by for this table.
            continue

        total_rows = len(tables[table_name])
        # Rough estimate: assume grouping leaves about 1/5 of the rows.
        estimated_rows = max(1, total_rows // 5)

        for _ in range(table_query_count):
            groupby_columns = random.sample(category_attrs, 1)

            # Roughly half of the time aggregate a numerical column,
            # otherwise count rows.
            if numerical_attrs and random.random() > 0.5:
                agg_column = random.choice(numerical_attrs)
                function = random.choice(numerical_functions)
                numerical = True
            else:
                agg_column = "*"
                function = 'COUNT'
                numerical = False

            schema_columns = groupby_columns + ([agg_column] if agg_column != "*" else [])

            queries.append({
                "Type": "SFAG",
                "SELECT": groupby_columns + [f"{function}({agg_column})"],
                "FROM": [table_name],
                "GROUP BY": groupby_columns,
                "AGGREGATION": [agg_column],
                "AGGREGATION Function": function,
                "Numerical": numerical,
                "WHERE Indices": list(range(total_rows)),  # no WHERE: every row qualifies
                "WHERE Total Rows": estimated_rows,  # estimated row count after grouping
                "WHERE": "",  # no WHERE condition
                "Tables": [table_name],
                "Combination": [],  # no filter combination
                "Operator": "NONE",
                "Filter Count": 0,
                "SCHEMA": create_schema(attr_desc_dict, schema_columns)
            })

    # 2. Multi-table SFAGJ queries
    join_combinations = [
        ['disease', 'drug'],
        ['drug', 'institutes'],
        ['disease', 'institutes'],
        ['disease', 'drug', 'institutes']
    ]

    queries_per_combo, remainder_j = divmod(sample_sfagj, len(join_combinations))

    for i, join_combo in enumerate(tqdm(join_combinations, desc="生成SFAGJ查询")):
        combo_query_count = queries_per_combo + (1 if i < remainder_j else 0)

        # Categorical and numerical attributes of every table in the join.
        join_category_attrs = [
            (table, attr)
            for table in join_combo
            for attr in category_attr.get(table, [])
        ]
        join_numerical_attrs = [
            (table, attr)
            for table in join_combo
            for attr in numerical_attr.get(table, [])
        ]
        if not join_category_attrs:
            continue

        # Rough estimate of the joined + grouped size: a third of the smallest
        # loaded table, and at least 1 (also when no table is loaded).
        smallest_table = min(
            (len(tables[table]) for table in join_combo if table in tables),
            default=0,
        )
        estimated_rows = max(1, smallest_table // 3)

        for _ in range(combo_query_count):
            groupby_key = random.choice(join_category_attrs)
            groupby_column = f"{groupby_key[0]}.{groupby_key[1]}"

            if join_numerical_attrs and random.random() > 0.5:
                agg_key = random.choice(join_numerical_attrs)
                agg_column = f"{agg_key[0]}.{agg_key[1]}"
                function = random.choice(numerical_functions)
                numerical = True
                selected_attrs = [groupby_key, agg_key]
            else:
                agg_column = "*"
                function = 'COUNT'
                numerical = False
                selected_attrs = [groupby_key]

            queries.append({
                "Type": "SFAGJ",
                "SELECT": [groupby_column, f"{function}({agg_column})"],
                "FROM": join_combo,
                "JOIN_TYPE": "INNER JOIN",
                "GROUP BY": [groupby_column],
                "AGGREGATION": [agg_column],
                "AGGREGATION Function": function,
                "Numerical": numerical,
                "WHERE Indices": [],  # concrete indices are only known after running the JOIN
                "WHERE Total Rows": estimated_rows,
                "WHERE": "",  # no WHERE condition
                "Tables": join_combo,
                "Combination": [],  # no filter combination
                "Operator": "NONE",
                "Filter Count": 0,
                "SCHEMA": create_multi_table_schema(join_combo, selected_attrs)
            })

    return queries


# Generate the SFAG / SFAGJ queries
sfag_queries = generate_aggregation_groupby_queries(tables)

#### 构建SELECT FROM WHERE GROUP BY AGGREGATION查询

In [None]:
def generate_groupby_aggregation_queries(valid_where, tables):
    """Build WHERE + GROUP BY + aggregation queries (``SFWGA`` / ``SFWGAJ``).

    The grouping column comes from ``category_attr``. The aggregated column is
    a numerical attribute (per ``numerical_attr``) combined with one of
    MAX/MIN/AVG/SUM when available, and ``COUNT(*)`` otherwise. Tables or join
    combinations without a categorical attribute yield no query.

    Change from the earlier version: a WHERE condition is only applied to a
    join combination that contains every table named in the condition's
    ``"Tables"`` entry, so the predicate never references a table missing
    from FROM.

    Args:
        valid_where: Candidate WHERE-condition dicts (``"Tables"`` and
            ``"Combination"`` keys are read).
        tables: Not used by this generator; kept so that all generators share
            the same call signature.

    Returns:
        A list of query dicts: shallow copies of the WHERE dicts extended with
        SELECT / FROM / GROUP BY / AGGREGATION / SCHEMA information.
    """
    queries = []
    aggregation_functions = ['COUNT', 'MAX', 'MIN', 'AVG', 'SUM']
    # COUNT is reserved for the "*" fallback; numerical columns use the rest.
    numerical_functions = aggregation_functions[1:]

    single_table_where = select_where_by_ratio(valid_where, sample_sfwga, single_table_only=True, multi_table_ratio=0.7)

    for where_dict in single_table_where:
        table_name = where_dict["Tables"][0]

        category_attrs = category_attr.get(table_name, [])
        numerical_attrs = numerical_attr.get(table_name, [])
        if not category_attrs:
            continue

        groupby_columns = random.sample(category_attrs, 1)

        if numerical_attrs:
            agg_column = random.choice(numerical_attrs)
            function = random.choice(numerical_functions)
            numerical = True
        else:
            agg_column = "*"
            function = 'COUNT'
            numerical = False

        # Schema = grouping column + aggregated column (if any) + WHERE
        # columns, with WHERE columns added only when not already present.
        query_attr_list = groupby_columns + ([agg_column] if agg_column != "*" else [])
        for item in where_dict["Combination"]:
            if item[1] not in query_attr_list:
                query_attr_list.append(item[1])

        query_dict = where_dict.copy()
        query_dict.update({
            "Type": "SFWGA",
            "SELECT": groupby_columns + [f"{function}({agg_column})"],
            "FROM": [table_name],
            "GROUP BY": groupby_columns,
            "AGGREGATION": [agg_column],
            "AGGREGATION Function": function,
            "Numerical": numerical,
            "SCHEMA": create_schema(attr_desc_dict, query_attr_list)
        })
        queries.append(query_dict)

    # JOIN GROUP BY AGGREGATION queries
    multi_table_where = select_where_by_ratio(valid_where, sample_sfwgaj//2, single_table_only=False, multi_table_ratio=0.6)
    single_table_for_join = select_where_by_ratio(valid_where, sample_sfwgaj//2, single_table_only=True, multi_table_ratio=0.8)

    join_combinations = [['disease', 'drug'], ['drug', 'institutes']]

    for where_dict in multi_table_where + single_table_for_join[:10]:
        where_tables = set(where_dict.get("Tables", []))

        for join_combo in join_combinations:
            # Skip joins lacking a table that the WHERE condition filters on.
            if not where_tables.issubset(join_combo):
                continue

            join_category_attrs = [
                (table, attr)
                for table in join_combo
                for attr in category_attr.get(table, [])
            ]
            join_numerical_attrs = [
                (table, attr)
                for table in join_combo
                for attr in numerical_attr.get(table, [])
            ]
            if not join_category_attrs:
                continue

            groupby_key = random.choice(join_category_attrs)
            groupby_column = f"{groupby_key[0]}.{groupby_key[1]}"

            if join_numerical_attrs:
                agg_key = random.choice(join_numerical_attrs)
                agg_column = f"{agg_key[0]}.{agg_key[1]}"
                function = random.choice(numerical_functions)
                numerical = True
                selected_attrs = [groupby_key, agg_key]
            else:
                agg_column = "*"
                function = 'COUNT'
                numerical = False
                selected_attrs = [groupby_key]

            query_dict = where_dict.copy()
            query_dict.update({
                "Type": "SFWGAJ",
                "SELECT": [groupby_column, f"{function}({agg_column})"],
                "FROM": join_combo,
                "JOIN_TYPE": "INNER JOIN",
                "GROUP BY": [groupby_column],
                "AGGREGATION": [agg_column],
                "AGGREGATION Function": function,
                "Numerical": numerical,
                "SCHEMA": create_multi_table_schema(join_combo, selected_attrs)
            })
            queries.append(query_dict)

    return queries


groupby_aggregation_queries = generate_groupby_aggregation_queries(valid_where, tables)

#### 构建SELECT FROM WHERE GROUP BY AGGREGATION TOP-K查询

In [None]:
def generate_groupby_aggregation_topk_queries(valid_where, tables):
    """Build WHERE + GROUP BY + aggregation + TOP-K queries (``SFWGAT`` / ``SFWGATJ``).

    Each query groups by one categorical attribute (per ``category_attr``).
    Roughly half of the time, when a numerical attribute exists, it aggregates
    that attribute with one of MAX/MIN/AVG/SUM; otherwise it uses
    ``COUNT(*)``. The result is ordered by the aggregate expression with a
    random direction and a limit drawn from ``limit_list``.

    Change from the earlier version: a WHERE condition is only applied to a
    join combination that contains every table named in the condition's
    ``"Tables"`` entry, so the predicate never references a table missing
    from FROM.

    Args:
        valid_where: Candidate WHERE-condition dicts (``"Tables"`` and
            ``"Combination"`` keys are read).
        tables: Not used by this generator; kept so that all generators share
            the same call signature.

    Returns:
        A list of query dicts: shallow copies of the WHERE dicts extended with
        SELECT / FROM / GROUP BY / AGGREGATION / ORDER BY / LIMIT / SCHEMA.
    """
    queries = []
    aggregation_functions = ['COUNT', 'MAX', 'MIN', 'AVG', 'SUM']
    # COUNT is reserved for the "*" fallback; numerical columns use the rest.
    numerical_functions = aggregation_functions[1:]
    order_options = ['ASC', 'DESC']

    single_table_where = select_where_by_ratio(valid_where, sample_sfwgat, single_table_only=True, multi_table_ratio=0.7)

    for where_dict in single_table_where:
        table_name = where_dict["Tables"][0]

        category_attrs = category_attr.get(table_name, [])
        numerical_attrs = numerical_attr.get(table_name, [])
        if not category_attrs:
            continue

        groupby_columns = random.sample(category_attrs, 1)

        if numerical_attrs and random.random() > 0.5:
            agg_column = random.choice(numerical_attrs)
            function = random.choice(numerical_functions)
            numerical = True
        else:
            agg_column = "*"
            function = 'COUNT'
            numerical = False

        # The aggregate expression is both selected and ordered by.
        aggregate_expr = f"{function}({agg_column})"
        order_type = random.choice(order_options)
        limit_value = random.choice(limit_list)

        # Schema = grouping column + aggregated column (if any) + WHERE
        # columns, with WHERE columns added only when not already present.
        query_attr_list = groupby_columns + ([agg_column] if agg_column != "*" else [])
        for item in where_dict["Combination"]:
            if item[1] not in query_attr_list:
                query_attr_list.append(item[1])

        query_dict = where_dict.copy()
        query_dict.update({
            "Type": "SFWGAT",
            "SELECT": groupby_columns + [aggregate_expr],
            "FROM": [table_name],
            "GROUP BY": groupby_columns,
            "AGGREGATION": [agg_column],
            "AGGREGATION Function": function,
            "Numerical": numerical,
            "ORDER BY": [aggregate_expr, order_type],
            "LIMIT": limit_value,
            "SCHEMA": create_schema(attr_desc_dict, query_attr_list)
        })
        queries.append(query_dict)

    # JOIN GROUP BY AGGREGATION TOP-K queries
    multi_table_where = select_where_by_ratio(valid_where, sample_sfwgatj//2, single_table_only=False, multi_table_ratio=0.6)
    single_table_for_join = select_where_by_ratio(valid_where, sample_sfwgatj//2, single_table_only=True, multi_table_ratio=0.8)

    join_combinations = [['disease', 'drug'], ['drug', 'institutes']]

    for where_dict in multi_table_where + single_table_for_join[:8]:
        where_tables = set(where_dict.get("Tables", []))

        for join_combo in join_combinations:
            # Skip joins lacking a table that the WHERE condition filters on.
            if not where_tables.issubset(join_combo):
                continue

            join_category_attrs = [
                (table, attr)
                for table in join_combo
                for attr in category_attr.get(table, [])
            ]
            join_numerical_attrs = [
                (table, attr)
                for table in join_combo
                for attr in numerical_attr.get(table, [])
            ]
            if not join_category_attrs:
                continue

            groupby_key = random.choice(join_category_attrs)
            groupby_column = f"{groupby_key[0]}.{groupby_key[1]}"

            if join_numerical_attrs and random.random() > 0.5:
                agg_key = random.choice(join_numerical_attrs)
                agg_column = f"{agg_key[0]}.{agg_key[1]}"
                function = random.choice(numerical_functions)
                numerical = True
                selected_attrs = [groupby_key, agg_key]
            else:
                agg_column = "*"
                function = 'COUNT'
                numerical = False
                selected_attrs = [groupby_key]

            aggregate_expr = f"{function}({agg_column})"
            order_type = random.choice(order_options)
            limit_value = random.choice(limit_list)

            query_dict = where_dict.copy()
            query_dict.update({
                "Type": "SFWGATJ",
                "SELECT": [groupby_column, aggregate_expr],
                "FROM": join_combo,
                "JOIN_TYPE": "INNER JOIN",
                "GROUP BY": [groupby_column],
                "AGGREGATION": [agg_column],
                "AGGREGATION Function": function,
                "Numerical": numerical,
                "ORDER BY": [aggregate_expr, order_type],
                "LIMIT": limit_value,
                "SCHEMA": create_multi_table_schema(join_combo, selected_attrs)
            })
            queries.append(query_dict)

    return queries


groupby_aggregation_topk_queries = generate_groupby_aggregation_topk_queries(valid_where, tables)

def generate_sql_query(query_dict):
    """Render the SELECT / FROM / JOIN / WHERE part of a query dict as SQL.

    This version does not render GROUP BY, ORDER BY or LIMIT.
    NOTE(review): this file defines ``generate_sql_query`` again further
    down; at runtime the last definition executed is the one in effect.

    Args:
        query_dict: Query description. ``"SELECT"`` (list of column
            expressions) is required; ``"FROM"`` (list of table names) and
            ``"WHERE"`` (condition string) are optional lookups, but an empty
            or missing ``"FROM"`` raises ``IndexError``.

    Returns:
        The SQL text, one clause per line, terminated by ``;``. An empty
        ``"WHERE"`` value (as produced by the WHERE-less generators) is
        omitted instead of being rendered as a dangling ``WHERE``.
    """
    tables = query_dict.get("FROM", [])

    select_clause = "SELECT " + ", ".join(query_dict["SELECT"])

    from_clause = f"FROM {tables[0]}"
    # Chain the remaining tables, joining each one to its predecessor.
    for prev_table, next_table in zip(tables, tables[1:]):
        join_condition = None
        for t1, t2, key1, key2 in table_relationships:
            # A relationship matches in either direction; the predecessor
            # table is always written on the left-hand side.
            if prev_table == t1 and next_table == t2:
                join_condition = f"{t1}.{key1} = {t2}.{key2}"
                break
            if prev_table == t2 and next_table == t1:
                join_condition = f"{t2}.{key2} = {t1}.{key1}"
                break

        # Tables without a known relationship to their predecessor are left
        # out of the FROM clause (unchanged best-effort behaviour).
        if join_condition:
            from_clause += f"\nINNER JOIN {next_table} ON {join_condition}"

    sql_parts = [select_clause, from_clause]

    # Only emit WHERE when there is an actual condition.
    where_condition = query_dict.get("WHERE")
    if where_condition:
        sql_parts.append(f"WHERE {where_condition}")

    return "\n".join(sql_parts) + ";"

def generate_schema_sql(schema_dict, table_name="Medical_Table"):
    """Render a schema dict as a ``CREATE TABLE`` statement.

    Args:
        schema_dict: Mapping of column name to a ``(data_type, description)``
            pair. A non-empty description becomes the column's ``COMMENT``.
        table_name: Name used in the ``CREATE TABLE`` statement.

    Returns:
        The DDL text, with one column definition per line.
    """
    columns = []

    for col_name, (data_type, description) in schema_dict.items():
        if description:
            # Double embedded single quotes so that a description such as
            # "patient's age" cannot terminate the SQL string literal early.
            escaped = str(description).replace("'", "''")
            comment = f" COMMENT '{escaped}'"
        else:
            comment = ""
        columns.append(f"    {col_name} {data_type}{comment}")

    return f"CREATE TABLE {table_name} (\n" + ",\n".join(columns) + "\n);"


#### SQL生成

In [None]:
def generate_sql_query(query_dict):
    """Render a query dict as a SQL statement.

    Clauses are emitted in the order SELECT, FROM (with INNER JOINs), WHERE,
    GROUP BY, ORDER BY, LIMIT; optional clauses are omitted when absent.

    Args:
        query_dict: Query description. ``"SELECT"`` (list of column
            expressions) is required. Optional keys: ``"FROM"`` (list of table
            names; an empty or missing list raises ``IndexError``),
            ``"WHERE"`` (condition string), ``"GROUP BY"`` (list of columns),
            ``"ORDER BY"`` (``[column, direction]`` pair) and ``"LIMIT"``.

    Returns:
        The SQL text, one clause per line, terminated by ``;``. An empty
        ``"WHERE"`` value (as produced by the WHERE-less generators) is
        omitted instead of being rendered as a dangling ``WHERE``.
    """
    tables = query_dict.get("FROM", [])

    sql_parts = ["SELECT " + ", ".join(query_dict["SELECT"])]

    from_clause = f"FROM {tables[0]}"
    # Chain the remaining tables, joining each one to its predecessor.
    for prev_table, next_table in zip(tables, tables[1:]):
        join_condition = None
        for t1, t2, key1, key2 in table_relationships:
            # A relationship matches in either direction; the predecessor
            # table is always written on the left-hand side.
            if prev_table == t1 and next_table == t2:
                join_condition = f"{t1}.{key1} = {t2}.{key2}"
                break
            if prev_table == t2 and next_table == t1:
                join_condition = f"{t2}.{key2} = {t1}.{key1}"
                break

        # Tables without a known relationship to their predecessor are left
        # out of the FROM clause (unchanged best-effort behaviour).
        if join_condition:
            from_clause += f"\nINNER JOIN {next_table} ON {join_condition}"
    sql_parts.append(from_clause)

    # Only emit WHERE when there is an actual condition.
    where_condition = query_dict.get("WHERE")
    if where_condition:
        sql_parts.append(f"WHERE {where_condition}")

    if "GROUP BY" in query_dict:
        sql_parts.append(f"GROUP BY {', '.join(query_dict['GROUP BY'])}")

    if "ORDER BY" in query_dict:
        order_col, order_type = query_dict["ORDER BY"]
        sql_parts.append(f"ORDER BY {order_col} {order_type}")

    if "LIMIT" in query_dict:
        sql_parts.append(f"LIMIT {query_dict['LIMIT']}")

    return "\n".join(sql_parts) + ";"

#### 保存所有查询结果

In [None]:
import json
import os
from datetime import datetime

def generate_sql_query(query_dict):
    """Render a query dictionary as a SQL statement over the medical tables.

    Args:
        query_dict: Mapping with the optional keys ``FROM`` (list of table
            names, default ``["medical_table"]``), ``SELECT`` (list of
            columns, default ``["*"]``), ``WHERE`` (condition string),
            ``GROUP BY`` (list of columns), ``ORDER BY`` (a
            ``(column, direction)`` pair) and ``LIMIT``.

    Returns:
        The SQL text, one clause per line, terminated with ``;``.

    Each table after the first is INNER JOINed to the table listed just
    before it.  Pairs among disease / drug / institutes use fixed join
    conditions; any other pair prints a warning and falls back to a
    ``<prev>_id`` equality.
    """
    # Use the defaults when a key is missing *or* explicitly empty; an empty
    # FROM list would otherwise raise IndexError and an empty SELECT list
    # would emit "SELECT " with no columns.
    tables = query_dict.get("FROM") or ["medical_table"]
    select_fields = query_dict.get("SELECT") or ["*"]

    # Join conditions keyed by the unordered table pair: the condition text
    # is the same whichever of the two tables is listed first.
    join_conditions = {
        frozenset(("disease", "drug")):
            "disease.disease_name = drug.disease_name",
        frozenset(("drug", "institutes")):
            "drug.manufacturer = institutes.institution_name",
        # research_diseases is a '||'-separated multi-value column
        frozenset(("disease", "institutes")):
            "FIND_IN_SET(disease.disease_name, REPLACE(institutes.research_diseases, '||', ',')) > 0",
    }

    from_lines = [f"FROM {tables[0]}"]
    for prev_table, current_table in zip(tables, tables[1:]):
        join_condition = join_conditions.get(frozenset((prev_table, current_table)))
        if not join_condition:
            print(f"Warning: No specific join condition found for {prev_table} -> {current_table}")
            join_condition = f"{prev_table}.{prev_table}_id = {current_table}.{prev_table}_id"
        from_lines.append(f"INNER JOIN {current_table} ON {join_condition}")

    sql_parts = ["SELECT " + ", ".join(select_fields), "\n".join(from_lines)]

    if query_dict.get("WHERE"):
        sql_parts.append(f"WHERE {query_dict['WHERE']}")
    if query_dict.get("GROUP BY"):
        sql_parts.append(f"GROUP BY {', '.join(query_dict['GROUP BY'])}")
    if query_dict.get("ORDER BY"):
        order_col, order_type = query_dict["ORDER BY"]
        sql_parts.append(f"ORDER BY {order_col} {order_type}")
    # Compare against None so that an explicit LIMIT 0 is not dropped
    if query_dict.get("LIMIT") is not None:
        sql_parts.append(f"LIMIT {query_dict['LIMIT']}")

    return "\n".join(sql_parts) + ";"

def generate_schema_sql(schema_dict, table_name="Medical_Data"):
    """Build a CREATE TABLE statement from a schema mapping.

    Args:
        schema_dict: Mapping of column name to a ``(data_type, description)``
            pair.  A non-empty description becomes a ``COMMENT '...'`` clause.
        table_name: Name used in the CREATE TABLE header.

    Returns:
        The statement text with one indented column definition per line.
    """
    columns = []
    for col_name, (data_type, description) in schema_dict.items():
        comment = ""
        if description:
            # Double embedded single quotes so the description cannot
            # terminate the COMMENT string literal early.
            escaped = str(description).replace("'", "''")
            comment = f" COMMENT '{escaped}'"
        columns.append(f"    {col_name} {data_type}{comment}")

    return f"CREATE TABLE {table_name} (\n" + ",\n".join(columns) + "\n);"

def save_queries_with_sql(all_queries, output_dir="./sql_queries/"):
    """Write queries to ``<output_dir>/<Type>.sql``, one file per query type.

    Queries are grouped by their ``Type`` entry (``'Unknown'`` when absent).
    Each query contributes a header comment, a CREATE TABLE statement built
    from its ``SCHEMA`` and the SQL built from the query itself.  Missing
    ``FROM`` / ``SELECT`` / ``SCHEMA`` entries are filled with defaults, and
    those defaults are stored back into the query dictionaries passed in.

    Returns:
        The number of queries written (0 when ``all_queries`` is empty).
    """
    if not all_queries:
        print("No queries to save.")
        return 0

    os.makedirs(output_dir, exist_ok=True)

    # Bucket the queries by type, keeping first-seen order of the types
    grouped = {}
    for entry in all_queries:
        grouped.setdefault(entry.get('Type', 'Unknown'), []).append(entry)

    saved_count = 0
    for type_name, bucket in grouped.items():
        sections = []
        for position, entry in enumerate(bucket, start=1):
            # Fill in defaults for anything the query does not define
            entry.setdefault('FROM', ['medical_table'])
            entry.setdefault('SELECT', ['*'])
            entry.setdefault('SCHEMA', {'id': ['INT', 'Primary key']})

            # A single-table query names its schema after that table
            sources = entry['FROM']
            schema_target = sources[0] if len(sources) == 1 else "Medical_Data"
            schema_sql = generate_schema_sql(entry["SCHEMA"], schema_target)
            query_sql = generate_sql_query(entry)

            sections.append(
                f"-- Query {position} ({type_name})\n"
                f"-- Rows: {entry.get('WHERE Total Rows', 'N/A')}\n\n"
                + schema_sql + "\n\n"
                + query_sql + "\n" + "-" * 50 + "\n\n"
            )

        target_path = os.path.join(output_dir, f"{type_name}.sql")
        with open(target_path, 'w', encoding='utf-8') as handle:
            handle.write("".join(sections))

        saved_count += len(bucket)

    print(f"Saved {saved_count} queries to {output_dir}")
    return saved_count

# 收集所有查询变量
def collect_all_queries():
    """Gather the per-type query lists that exist as module globals.

    Looks up a fixed list of variable names in ``globals()`` and concatenates,
    in that order, every one that is defined and non-empty.

    Returns:
        A new list with all collected queries (empty when none are defined).
    """
    candidate_names = (
        'sf_queries', 'sfj_queries', 'sfw_queries', 'join_queries',
        'topk_queries', 'groupby_queries', 'aggregation_queries',
        'groupby_aggregation_queries', 'groupby_aggregation_topk_queries',
        'sfag_queries',
    )

    namespace = globals()
    gathered = []
    for name in candidate_names:
        batch = namespace.get(name)
        # Undefined names and empty lists are both skipped
        if batch:
            gathered.extend(batch)

    return gathered

# 执行保存
try:
    # Prefer an explicit `all_queries` list; otherwise gather the per-type lists
    collected = globals().get('all_queries') or collect_all_queries()
    if collected:
        save_queries_with_sql(collected)
    else:
        print("No queries found. Run query generation functions first.")
except Exception as e:
    # Best-effort notebook cell: report the problem instead of raising
    print(f"Error: {e}")

### 固定Filter和Join

In [None]:
# ======================== 医学数据集固定模板查询生成模块 ========================
import random
import os
from datetime import datetime

# Fixed query templates.  Each compact row below is
# (name, tables to join, (table, attribute) filter pairs) and is expanded into
# a dict with the keys "name", "filters", "tables" and "filter_count".
MEDICAL_TEMPLATES = [
    {
        "name": template_name,
        "filters": [{"table": table, "attr": attr} for table, attr in filter_pairs],
        "tables": list(joined_tables),
        "filter_count": len(filter_pairs),
    }
    for template_name, joined_tables, filter_pairs in (
        # 1-filter templates (2/10 = 20%) - must include a JOIN
        ("disease_join_analysis", ("disease", "drug"),
         (("disease", "disease_type"),)),
        ("drug_join_analysis", ("drug", "institutes"),
         (("drug", "pharmaceutical_form"),)),

        # 2-filter templates (3/10 = 30%) - one filter on each joined table
        ("disease_drug_matching", ("disease", "drug"),
         (("disease", "disease_type"),
          ("drug", "pharmaceutical_form"))),
        ("drug_institute_analysis", ("drug", "institutes"),
         (("drug", "prescription_status"),
          ("institutes", "institution_type"))),
        ("disease_institute_research", ("disease", "institutes"),
         (("disease", "treatments"),
          ("institutes", "research_fields"))),

        # 3-filter templates (3/10 = 30%) - at least one filter per table
        ("comprehensive_disease_drug", ("disease", "drug"),
         (("disease", "disease_type"),
          ("disease", "treatments"),
          ("drug", "pharmaceutical_form"))),
        ("drug_manufacturer_research", ("drug", "institutes"),
         (("drug", "administration_route"),
          ("institutes", "institution_type"),
          ("institutes", "institution_country"))),
        ("three_table_basic", ("disease", "drug", "institutes"),
         (("disease", "disease_type"),
          ("drug", "pharmaceutical_form"),
          ("institutes", "institution_type"))),

        # 4-filter template (1/10 = 10%) - at least one filter per table
        ("detailed_medical_analysis", ("disease", "drug", "institutes"),
         (("disease", "disease_type"),
          ("disease", "prognosis"),
          ("drug", "pharmaceutical_form"),
          ("institutes", "institution_type"))),

        # 5-filter template (1/10 = 10%) - at least one filter per table
        ("comprehensive_medical_ecosystem", ("disease", "drug", "institutes"),
         (("disease", "disease_type"),
          ("disease", "treatments"),
          ("drug", "pharmaceutical_form"),
          ("drug", "prescription_status"),
          ("institutes", "institution_type"))),
    )
]

# Query-type codes: the six base shapes first, then the same six with a "J" suffix
QUERY_TYPES = [
    f"{base}{suffix}"
    for suffix in ("", "J")
    for base in ("SF", "SFW", "SFWT", "SFWG", "SFWA", "SFWGA")
]

# Mapping from template name to a human-readable description.
# NOTE(review): only "disease_drug_matching" matches a name in
# MEDICAL_TEMPLATES; the other keys look like leftovers from an earlier
# template set - confirm whether they are still needed.
TEMPLATE_DESCRIPTIONS = dict(
    disease_type_analysis="疾病类型分析",
    treatment_study="治疗方法研究",
    drug_form_analysis="药物剂型分析",
    prescription_analysis="处方状态分析",
    institution_distribution="医疗机构分布",
    research_capacity="科研能力评估",
    disease_drug_matching="疾病-药物匹配分析",
    manufacturer_analysis="制药企业分析",
    disease_research_focus="疾病研究重点",
    medical_ecosystem="医疗生态系统分析",
    treatment_pipeline="治疗管线分析",
)

class MedicalSQLTemplateGenerator:
    def __init__(self, tables, attr_value_dict, filter_dict):
        self.tables = tables
        self.attr_value_dict = attr_value_dict
        self.filter_dict = filter_dict
        
        # 修复：正确定义表之间的JOIN关系
        self.join_relationships = {
            ("disease", "drug"): ("disease_name", "disease_name"),
            ("drug", "institutes"): ("manufacturer", "institution_name"),
            ("disease", "institutes"): ("disease_name", "research_diseases")  # 特殊处理
        }
        
        self.table_attrs = {
            "disease": ["disease_name", "disease_type", "pathogenesis", "treatments", "prognosis"],
            "drug": ["generic_name", "pharmaceutical_form", "manufacturer", "administration_route", "prescription_status"],
            "institutes": ["institution_name", "institution_type", "institution_country", "research_fields", "key_technologies"]
        }
        
        self.numerical_attrs = []
        self.category_attrs = ["disease_type", "pathogenesis", "treatments", "prognosis", "institution_type", 
                              "pharmaceutical_form", "administration_route", "prescription_status"]
    
    
    def validate_template_constraints(self, template):
        """验证模板是否满足JOIN和Filter分布约束"""
        filters = template["filters"]
        tables = template["tables"]
        filter_count = template["filter_count"]
        
        # 统计每个表的filter数量
        table_filter_count = {}
        for filter_def in filters:
            table_name = filter_def["table"]
            table_filter_count[table_name] = table_filter_count.get(table_name, 0) + 1
        
        # 验证约束条件
        if filter_count == 1:
            # 1个Filter必须包括JOIN
            if len(tables) < 2:
                return False, "1-filter template must include JOIN"
        
        elif filter_count == 2:
            # 2个Filter的JOIN：两个表各有一个Filter
            if len(tables) != 2:
                return False, "2-filter template must be 2-table JOIN"
            for table in tables:
                if table_filter_count.get(table, 0) < 1:
                    return False, f"2-filter template: table {table} must have at least 1 filter"
        
        elif filter_count >= 3:
            # 3个及以上Filter：每个JOIN表都至少有一个Filter
            for table in tables:
                if table_filter_count.get(table, 0) < 1:
                    return False, f"{filter_count}-filter template: table {table} must have at least 1 filter"
        
        return True, "Valid"
    
    def get_alternative_filter_values(self, table_name, attr, exclude_values=None):
        """获取指定属性的备选值"""
        if exclude_values is None:
            exclude_values = []
            
        if table_name not in self.attr_value_dict or attr not in self.attr_value_dict[table_name]:
            return []
            
        values = self.attr_value_dict[table_name][attr]
        clean_values = [str(v).strip() for v in values 
                       if '||' not in str(v) and ',' not in str(v) 
                       and str(v).strip() and str(v).strip() not in exclude_values]
        
        return clean_values[:10]  # 返回前10个备选值
    
    def test_filter_combination_with_values(self, filter_combination, relaxation_level=0):
        """测试特定Filter值组合的筛选结果数量"""
        conditions = []
        tables_involved = set()
        
        for filter_def in filter_combination:
            table_name = filter_def["table"]
            filter_attr = filter_def["attr"]
            filter_value = filter_def["value"]
            tables_involved.add(table_name)
            
            # 直接使用指定的值生成条件
            if filter_attr in ["research_diseases", "treatments", "drugs", "common_symptoms", "complications", "research_fields", "key_technologies"]:
                if relaxation_level >= 1:
                    condition = f"{table_name}.{filter_attr} LIKE '%{filter_value}%'"
                else:
                    condition = f"{table_name}.{filter_attr} = '{filter_value}'"
            else:
                condition = f"{table_name}.{filter_attr} = '{filter_value}'"
            
            conditions.append(condition)
        
        if not conditions:
            return 0, [], list(tables_involved)
        
        # 使用Filter dict进行更准确的估算
        estimated_rows = float('inf')
        for filter_def in filter_combination:
            table_name = filter_def["table"]
            filter_attr = filter_def["attr"]
            filter_value = filter_def["value"]
            
            # 检查这个具体的filter值在filter_dict中的结果
            if (table_name in self.filter_dict and 
                filter_attr in self.filter_dict[table_name]):
                
                # 寻找匹配的条件
                for condition_str, indices in self.filter_dict[table_name][filter_attr].items():
                    if filter_value in condition_str:
                        estimated_rows = min(estimated_rows, len(indices))
                        break
        
        # 如果没有找到精确匹配，使用保守估算
        if estimated_rows == float('inf'):
            if len(tables_involved) == 1:
                table_rows = len(self.tables.get(list(tables_involved)[0], []))
            else:
                table_rows = min(len(self.tables.get(t, [])) for t in tables_involved if t in self.tables)
            
            base_ratio = 0.3 ** len(conditions)
            relaxation_bonus = 1.3 ** relaxation_level
            estimated_rows = max(1, int(table_rows * base_ratio * relaxation_bonus))
        
        return estimated_rows, conditions, list(tables_involved)
    
    def apply_template_filters_with_value_replacement(self, template, min_required_rows=3, min_attempts=10, max_attempts=20):
        """应用模板Filter，支持动态值替换策略"""
        # 首先验证模板约束
        is_valid, validation_msg = self.validate_template_constraints(template)
        if not is_valid:
            return None, None, None, f"CONSTRAINT_VIOLATION: {validation_msg}"
        
        filters = template["filters"]
        tables_involved = template["tables"]
        
        if not tables_involved:
            return None, None, None, "NO_VALID_TABLES"
        
        # 先尝试原始的放宽策略
        initial_result = self._try_initial_approach(filters, tables_involved)
        
        # 如果初始结果满足要求，直接返回
        if initial_result and initial_result[2] >= min_required_rows:
            return initial_result
        
        # 如果初始结果小于3条，使用值替换策略（至少尝试min_attempts次）
        replacement_result = self._try_value_replacement_strategy(
            filters, tables_involved, min_required_rows, min_attempts, max_attempts
        )
        
        # 返回更好的结果
        if replacement_result:
            return replacement_result
        elif initial_result:
            return initial_result
        else:
            return None, None, None, "ALL_STRATEGIES_FAILED"
    
    def _try_initial_approach(self, filters, tables_involved):
        """尝试初始方法（简单的放宽策略）"""
        for relaxation_level in range(4):
            conditions = []
            success = True
            
            for filter_def in filters:
                table_name = filter_def["table"]
                filter_attr = filter_def["attr"]
                
                if table_name not in tables_involved:
                    success = False
                    break
                
                val, condition = self.get_filter_value(table_name, filter_attr, relaxation_level)
                if condition:
                    conditions.append(condition)
                else:
                    if relaxation_level < 3:
                        success = False
                        break
            
            if success and len(conditions) >= len(filters) // 2:
                if len(tables_involved) == 1:
                    table_rows = len(self.tables.get(tables_involved[0], []))
                else:
                    table_rows = min(len(self.tables.get(t, [])) for t in tables_involved if t in self.tables)
                
                base_ratio = 0.3 ** len(conditions)
                relaxation_bonus = 1.5 ** relaxation_level
                estimated_rows = max(1, int(table_rows * base_ratio * relaxation_bonus))
                
                return list(range(estimated_rows)), conditions, estimated_rows, f"INITIAL_RELAXATION_{relaxation_level}"
        
        return None
    
    def _try_value_replacement_strategy(self, filters, tables_involved, min_required_rows, min_attempts, max_attempts):
        """尝试值替换策略，确保至少尝试min_attempts次"""
        import itertools
        import random
        
        best_result = None
        attempt = 0
        successful_attempts = 0  # 真正成功（>=min_required_rows）的尝试次数
        all_attempts_results = []  # 记录所有尝试的结果，用于调试
        
        # 为每个Filter生成候选值
        filter_candidates = []
        for filter_def in filters:
            table_name = filter_def["table"]
            attr = filter_def["attr"]
            
            candidate_values = self.get_alternative_filter_values(table_name, attr)
            if not candidate_values:
                # 如果没有候选值，使用原始方法
                val, _ = self.get_filter_value(table_name, attr, 0)
                candidate_values = [val] if val else ["DEFAULT"]
            
            filter_candidates.append({
                "table": table_name,
                "attr": attr,
                "values": candidate_values[:8]  # 增加候选值数量以提供更多组合
            })
        
        # 生成所有可能的值组合
        value_combinations = list(itertools.product(
            *[fc["values"] for fc in filter_candidates]
        ))
        
        # 随机打乱顺序，避免总是使用相同的组合
        random.shuffle(value_combinations)
        
        # 如果组合数量不足，重复使用组合直到有足够的尝试机会
        if len(value_combinations) < max_attempts:
            # 重复组合列表，确保有足够的尝试机会
            extended_combinations = []
            for i in range(max_attempts):
                extended_combinations.append(value_combinations[i % len(value_combinations)])
            value_combinations = extended_combinations
        
        # 尝试不同的值组合
        for value_combo in value_combinations[:max_attempts]:
            attempt += 1
            
            # 构建当前尝试的Filter组合
            current_filters = []
            for i, filter_candidate in enumerate(filter_candidates):
                current_filters.append({
                    "table": filter_candidate["table"],
                    "attr": filter_candidate["attr"],
                    "value": value_combo[i]
                })
            
            # 测试当前组合
            estimated_rows, conditions, tables = self.test_filter_combination_with_values(current_filters)
            all_attempts_results.append(estimated_rows)  # 记录所有结果用于调试
            
            # 检查是否真正成功（>=min_required_rows）
            is_successful = estimated_rows >= min_required_rows
            if is_successful:
                successful_attempts += 1
                best_result = (list(range(estimated_rows)), conditions, estimated_rows, attempt, successful_attempts)
                
                # 只有在至少尝试了min_attempts次，并且找到了成功结果时才可能早期停止
                if attempt >= min_attempts:
                    break
            else:
                # 记录最佳结果（即使不满足要求）
                if best_result is None or estimated_rows > best_result[2]:
                    best_result = (list(range(estimated_rows)), conditions, estimated_rows, attempt, successful_attempts)
            
            # 强制至少尝试min_attempts次，除非已经找到了满足要求的结果
            if attempt < min_attempts:
                continue  # 继续尝试
        
        if best_result:
            indices, conditions, estimated_rows, attempts_used, final_successful_count = best_result
            
            # 生成详细的策略信息
            strategy_info = f"VALUE_REPLACEMENT_{attempts_used}_ATTEMPTS"
            
            if final_successful_count > 0:
                strategy_info += f"_SUCCESS_{final_successful_count}"
            else:
                strategy_info += f"_NOSUCCESS_BEST_{estimated_rows}"
                
            # 添加调试信息
            max_result = max(all_attempts_results) if all_attempts_results else 0
            min_result = min(all_attempts_results) if all_attempts_results else 0
            strategy_info += f"_RANGE_{min_result}TO{max_result}"
            
            return indices, conditions, estimated_rows, strategy_info
        else:
            return None, None, None, f"VALUE_REPLACEMENT_FAILED_AFTER_{attempt}_ATTEMPTS_MAX_{max(all_attempts_results) if all_attempts_results else 0}"
    
    def apply_template_filters(self, template, max_relaxation=3):
        """应用模板Filter，支持多级放宽策略和约束验证"""
        # 使用新的值替换策略，确保结果<3条时至少尝试10次
        return self.apply_template_filters_with_value_replacement(template, min_required_rows=3, min_attempts=40, max_attempts=50)
    
    def get_filter_value(self, table_name, attr, relaxation_level=0):
        """智能获取模板Filter值，支持多级放宽策略"""
        if table_name not in self.attr_value_dict or attr not in self.attr_value_dict[table_name]:
            return None, None
            
        values = self.attr_value_dict[table_name][attr]
        if not values:
            return None, None
        
        # 过滤掉包含分隔符的复合值，确保使用单一值
        clean_values = [str(v).strip() for v in values if '||' not in str(v) and ',' not in str(v) and str(v).strip()]
        if not clean_values:
            return None, None
        
        # 由于没有数值属性，只处理文本属性
        val = random.choice(clean_values)
        
        # 针对多值属性，使用包含匹配（更宽松）
        if attr in ["research_diseases", "treatments", "drugs", "common_symptoms", "complications", "research_fields", "key_technologies"]:
            if relaxation_level >= 1:
                return val, f"{table_name}.{attr} LIKE '%{val}%'"
            else:
                return val, f"{table_name}.{attr} = '{val}'"
        else:
            return val, f"{table_name}.{attr} = '{val}'"
    
     def generate_join_clause(self, tables):
        """生成正确的JOIN子句"""
        if len(tables) <= 1:
            return ""
        
        join_parts = []
        
        for i in range(1, len(tables)):
            current_table = tables[i]
            join_condition = None
            
            # 寻找当前表与之前任一表的JOIN关系
            for prev_idx in range(i):
                prev_table = tables[prev_idx]
                
                # 检查直接的JOIN关系
                for (t1, t2), (key1, key2) in self.join_relationships.items():
                    if (prev_table == t1 and current_table == t2):
                        if key2 == "research_diseases":
                            # 特殊处理：research_diseases是多值字段
                            join_condition = f"FIND_IN_SET({prev_table}.{key1}, REPLACE({current_table}.{key2}, '||', ',')) > 0"
                        else:
                            join_condition = f"{prev_table}.{key1} = {current_table}.{key2}"
                        break
                    elif (prev_table == t2 and current_table == t1):
                        if key1 == "research_diseases":
                            # 特殊处理：research_diseases是多值字段  
                            join_condition = f"FIND_IN_SET({current_table}.{key2}, REPLACE({prev_table}.{key1}, '||', ',')) > 0"
                        else:
                            join_condition = f"{prev_table}.{key2} = {current_table}.{key1}"
                        break
                
                if join_condition:
                    break
            
            # 如果没有找到直接关系，尝试间接关系
            if not join_condition and len(tables) == 3:
                # 三表JOIN的特殊处理：disease -> drug -> institutes
                if i == 2 and tables == ["disease", "drug", "institutes"]:
                    join_condition = "drug.manufacturer = institutes.institution_name"
                elif i == 2 and tables == ["disease", "institutes", "drug"]:
                    join_condition = "institutes.institution_name = drug.manufacturer"
            
            if join_condition:
                join_parts.append(f"INNER JOIN {current_table} ON {join_condition}")
            else:
                # 最后的备选方案：使用业务主键
                print(f"Warning: Using fallback join for {current_table}")
                join_parts.append(f"INNER JOIN {current_table} ON {tables[0]}.disease_name = {current_table}.disease_name")
        
        return "\n" + "\n".join(join_parts) if join_parts else ""
    
    def generate_query_sql(self, template, query_type, query_id):
        """生成医学查询SQL，支持失败处理和值替换策略"""
        # 应用模板Filter
        result = self.apply_template_filters(template)
        
        # 处理生成失败的情况
        if result[0] is None or result[1] is None:
            # 检查第四个元素是字符串还是其他类型
            error_info = result[3]
            if isinstance(error_info, str) and (
                error_info in ["ALL_STRATEGIES_FAILED", "NO_VALID_TABLES"] or 
                error_info.startswith("CONSTRAINT_VIOLATION") or
                error_info.startswith("VALUE_REPLACEMENT_FAILED")
            ):
                return None  # 简单跳过失败的查询
            else:
                return None  # 其他类型的失败也跳过
        
        indices, conditions, estimated_rows, strategy_info = result
        
        # 获取涉及的表
        tables = template["tables"]
        
        # 强制所有查询都是JOIN查询（根据用户要求）
        is_join = True
        
        # 选择显示列
        select_attrs = []
        for table in tables[:2]:
            table_attrs = self.table_attrs.get(table, [])[:3]
            for attr in table_attrs:
                select_attrs.append(f"{table}.{attr}")
        
        select_clause = f"SELECT {', '.join(select_attrs[:4])}"
        from_clause = f"FROM {tables[0]}" + self.generate_join_clause(tables)
        where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
        
        sql_parts = [select_clause, from_clause]
        if where_clause:
            sql_parts.append(where_clause)
        
        # 根据查询类型添加子句
        base_type = query_type.rstrip('J')
        
        if base_type == "SFWT":
            # 由于没有数值属性，使用字符串排序
            order_attrs = []
            for table in tables:
                for attr in self.table_attrs.get(table, []):
                    if attr not in ["ID"]:  # 排除ID字段
                        order_attrs.append(f"{table}.{attr}")
            
            if order_attrs:
                order_col = random.choice(order_attrs)
                sql_parts.append(f"ORDER BY {order_col} {random.choice(['ASC', 'DESC'])}")
                limit_value = random.choice([5, 10, 15])
                sql_parts.append(f"LIMIT {limit_value}")
                # 对于TOP-K查询，实际结果可能少于估算值
                estimated_rows = min(estimated_rows, limit_value)
        
        elif base_type in ["SFWG", "SFWGA"]:
            category_attrs = []
            for table in tables:
                for attr in self.table_attrs.get(table, []):
                    if attr in self.category_attrs:
                        category_attrs.append(f"{table}.{attr}")
            
            if category_attrs:
                group_col = random.choice(category_attrs)
                sql_parts.append(f"GROUP BY {group_col}")
                
                if base_type == "SFWGA":
                    sql_parts[0] = f"SELECT {group_col}, COUNT(*)"
                
                 # GROUP BY查询的结果行数通常会减少
                estimated_rows = max(1, estimated_rows // 3)
        
        elif base_type == "SFWA":
            sql_parts[0] = "SELECT COUNT(*)"
            # 聚合查询只返回一行
            estimated_rows = 1
        
        # 生成完整的查询字符串
        query_sql = "\n".join(sql_parts) + ";"
        
        # 生成Schema
        schema_parts = []
        for table in tables:
            attrs = self.table_attrs.get(table, [])
            columns = [f"    {table}_id INTEGER PRIMARY KEY"]
            
            for attr in attrs:
                # 由于没有数值属性，所有属性都是VARCHAR
                columns.append(f"    {attr} VARCHAR(255)")
            
            schema = f"CREATE TABLE {table} (\n" + ",\n".join(columns) + "\n);"
            schema_parts.append(schema)
        
        schema_sql = "\n\n".join(schema_parts)
        
        # 组合完整输出
        filter_info = ", ".join([f"{f['table']}.{f['attr']}" for f in template["filters"]])
        output = f"-- Query {query_id} - {query_type}\n"
        output += f"-- Tables: {', '.join(tables)}\n"
        output += f"-- Filters: {filter_info}\n"
        output += f"-- Filter Count: {template['filter_count']}\n"
        
        # 根据策略类型显示详细信息
        if isinstance(strategy_info, str):
            if strategy_info.startswith("VALUE_REPLACEMENT"):
                # 解析策略信息
                parts = strategy_info.split("_")
                attempts = parts[2] if len(parts) > 2 else "N/A"
                
                if "_SUCCESS_" in strategy_info:
                    success_count = strategy_info.split("_SUCCESS_")[1].split("_")[0]
                    range_info = ""
                    if "_RANGE_" in strategy_info:
                        range_part = strategy_info.split("_RANGE_")[1]
                        range_info = f" (range: {range_part.replace('TO', '-')})"
                    output += f"-- Estimated Result Rows: {estimated_rows} (Value replacement: {attempts} attempts, {success_count} successful{range_info})\n"
                elif "_NOSUCCESS_" in strategy_info:
                    best_result = strategy_info.split("_NOSUCCESS_BEST_")[1].split("_")[0] if "_NOSUCCESS_BEST_" in strategy_info else "N/A"
                    range_info = ""
                    if "_RANGE_" in strategy_info:
                        range_part = strategy_info.split("_RANGE_")[1]
                        range_info = f" (range: {range_part.replace('TO', '-')})"
                    output += f"-- Estimated Result Rows: {estimated_rows} (Value replacement: {attempts} attempts, 0 successful, best: {best_result}{range_info})\n"
                elif "_BEST_OF_" in strategy_info:
                    total_attempts = strategy_info.split("_BEST_OF_")[1].split("_")[0] if "_BEST_OF_" in strategy_info else attempts
                    output += f"-- Estimated Result Rows: {estimated_rows} (Value replacement: best of {total_attempts} attempts, 0 successful)\n"
                else:
                    output += f"-- Estimated Result Rows: {estimated_rows} (Value replacement: {attempts} attempts)\n"
            elif strategy_info.startswith("INITIAL_RELAXATION"):
                relaxation_level = strategy_info.split("_")[-1]
                output += f"-- Estimated Result Rows: {estimated_rows} (Initial approach: relaxation level {relaxation_level})\n"
            else:
                output += f"-- Estimated Result Rows: {estimated_rows} (Strategy: {strategy_info})\n"
        else:
            output += f"-- Estimated Result Rows: {estimated_rows}\n"
        
        output += "\n"
        output += schema_sql + "\n\n"
        output += query_sql + "\n\n"
        output += "-" * 40 + "\n\n"
        
        return output

def generate_medical_template_queries(base_dir="./Medical_Template_Queries/",
                                      queries_per_type=6,
                                      max_attempts=18):
    """Generate SQL query files for every template in ``MEDICAL_TEMPLATES``.

    One sub-directory is created per template (named after ``template["name"]``)
    and, inside it, one ``<query_type>.sql`` file per query type. Templates
    whose ``"tables"`` entry lists more than one table additionally get the
    join query types (those ending in ``J``).

    Args:
        base_dir: Root directory that receives the per-template folders.
            Created if it does not exist.
        queries_per_type: Number of successfully generated queries to collect
            for each (template, query type) pair.
        max_attempts: Upper bound on calls to ``generate_query_sql`` for each
            (template, query type) pair, successful or not.

    Returns:
        Total number of queries successfully generated across all files.

    Notes:
        Reads the module-level ``tables``, ``attr_value_dict``, ``filter_dict``
        and ``MEDICAL_TEMPLATES``. A file is written for every query type even
        when no query could be generated; it then contains only the header.
    """
    os.makedirs(base_dir, exist_ok=True)

    generator = MedicalSQLTemplateGenerator(tables, attr_value_dict, filter_dict)
    total_generated = 0

    single_table_types = ["SF", "SFW", "SFWT", "SFWG", "SFWA", "SFAG"]
    join_types = ["SFJ", "SFWJ", "SFWTJ", "SFWGJ", "SFWAJ", "SFAGJ"]

    # Each template gets its own folder.
    for template in MEDICAL_TEMPLATES:
        template_dir = os.path.join(base_dir, template["name"])
        os.makedirs(template_dir, exist_ok=True)

        is_multi_table = len(template.get("tables", [])) > 1
        query_types = single_table_types + (join_types if is_multi_table else [])

        # Each query type gets its own file.
        for qtype in query_types:
            # Short file header.
            header = f"-- {template['name']} - {qtype}\n"
            header += "-- " + "=" * 40 + "\n\n"
            sql_content = [header]

            # Collect query instances until the quota is met or the attempt
            # budget is exhausted.
            generated = 0
            for _ in range(max_attempts):
                if generated >= queries_per_type:
                    break
                # Query ids are 1-based and advance only on success, so the
                # next id is always one past the number generated so far.
                sql = generator.generate_query_sql(template, qtype, generated + 1)

                if sql and "[GENERATION FAILED]" not in sql:
                    sql_content.append(sql)
                    generated += 1

            total_generated += generated

            # Save the file.
            filename = os.path.join(template_dir, f"{qtype}.sql")
            with open(filename, 'w', encoding='utf-8') as f:
                f.write("".join(sql_content))

    return total_generated

# Run the generation, but only when the module-level inputs read by
# generate_medical_template_queries() are already defined.
try:
    if not (globals().keys() >= {'tables', 'attr_value_dict', 'filter_dict'}):
        print("Required variables not found. Run data loading steps first.")
    else:
        total_queries = generate_medical_template_queries()
        print(f"Generated {total_queries} medical queries in ./Medical_Template_Queries/")
except Exception as exc:
    # Top-level boundary: report the failure instead of raising.
    print(f"Error: {exc}")