In [None]:
import itertools
import json
import math
import operator
import os
import pickle
import random
import re
from collections import Counter
from collections import defaultdict
from datetime import datetime

import openpyxl
import pandas as pd
from tqdm.autonotebook import tqdm

#### 定义输入输出路径，并加载数据
- `dataset_dirs`：四个NBA数据表（city、owner、team、player）的CSV文件路径
- `statistics_output_dir`：统计数据表输出路径，包括属性值、选择率、基数
- `valid_where_output_dir`：所有有效谓词组合的输出路径
- **注意**：仅支持 .csv 格式；player 表在 UTF-8 解码失败时会回退为 latin1 编码读取

In [None]:
#### Define the input/output paths and load the data
# dataset_dirs: CSV paths of the four tables
# statistics_output_dir: output path of the statistics table (attribute values, selectivity, cardinality)
# valid_where_output_dir: output path of all valid predicate combinations

# Dataset path configuration
dataset_dirs = {
    "city": "/data2/liujinqi/Benchmark/Query/NBA_AutoConstruct/JOIN/city.csv",
    "owner": "/data2/liujinqi/Benchmark/Query/NBA_AutoConstruct/JOIN/owner.csv",
    "team": "/data2/liujinqi/Benchmark/Query/NBA_AutoConstruct/JOIN/team.csv",
    "player": "/data2/liujinqi/Benchmark/Query/NBA_AutoConstruct/JOIN/player.csv",
}

statistics_output_dir = r"/data2/liujinqi/Benchmark/Query/NBA_AutoConstruct/JOIN/NBA_Statistics.csv"
valid_where_output_dir = r"/data2/liujinqi/Benchmark/Query/NBA_AutoConstruct/JOIN/NBA_valid_WHERE.json"


def load_table_csv(table_name, file_path):
    """Read one table's CSV file into a DataFrame.

    Every table is read as UTF-8. The player table is retried as latin1 when it
    is not valid UTF-8, and its placeholder columns (names starting with ``_``,
    blank names, or the ``Unnamed: N`` names pandas gives to empty header
    cells) are dropped.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
    """
    if table_name != "player":
        return pd.read_csv(file_path, encoding='utf-8')

    try:
        df = pd.read_csv(file_path, encoding='utf-8')
    except UnicodeDecodeError:
        # player.csv may not be UTF-8; latin1 can decode any byte sequence.
        df = pd.read_csv(file_path, encoding='latin1')

    # pandas never yields '' as a column name: an empty header cell becomes
    # "Unnamed: N", so that prefix has to be matched explicitly.
    cols_to_drop = [
        col for col in df.columns
        if str(col).startswith(('_', 'Unnamed:')) or not str(col).strip()
    ]
    return df.drop(columns=cols_to_drop)


# Load every table into its own DataFrame; a missing file yields an empty frame.
dataframes = {}
for table_name, file_path in dataset_dirs.items():
    try:
        dataframes[table_name] = load_table_csv(table_name, file_path)
    except FileNotFoundError:
        print(f"警告: 文件不存在 {file_path}")
        dataframes[table_name] = pd.DataFrame()
        continue

    print(f"成功加载 {table_name}: {len(dataframes[table_name])} 行, {len(dataframes[table_name].columns)} 列")
    print(f"  列名: {list(dataframes[table_name].columns)}")

# JOIN relationships between tables:
# (table1, table2): (join column of table1, join column of table2)
join_relationships = {
    ("owner", "team"): ("NBA_team", "team_name"),
    ("team", "city"): ("location", "city_name"),
    ("player", "team"): ("team", "team_name"),
}

# Multi-hop join paths between tables that are not directly related.
# NOTE(review): the hop ("team", "player") is the reverse of the
# ("player", "team") key in join_relationships, so looking that hop up
# directly raises KeyError; a consumer has to handle the swapped orientation.
table_paths = {
    ("owner", "city"): [("owner", "team"), ("team", "city")],
    ("player", "city"): [("player", "team"), ("team", "city")],
    ("owner", "player"): [("owner", "team"), ("team", "player")],
}

print("表关系定义完成:")
for (t1, t2), (f1, f2) in join_relationships.items():
    print(f"  {t1}.{f1} = {t2}.{f2}")

#### 定义属性，方便后面按不同属性类型设计不同的构造方法
- `attr_desc_dict`：全部属性的集合，以及对应的自然语言描述
- `city_attributes`、`owner_attributes`、`team_attributes`、`player_attributes`：各表的属性（取自各表的实际列名）
- `non_numerical_attr`：非数值属性的集合
- `numerical_attr`：数值属性的集合（人口、年龄、选秀年份、冠军数等）
- `category_attr`：固定类别的属性
- `multi_value_attributes`：多值的属性，用"||"分隔或者逗号分隔

In [None]:
#### Define attributes so later steps can pick a construction method per attribute type
# attr_desc_dict: every attribute with its natural-language description (currently all empty)
# city_attributes, owner_attributes, team_attributes, player_attributes: the columns of each table
# non_numerical_attr: non-numerical attributes, grouped per table
# numerical_attr: numerical attributes, grouped per table
# category_attr: attributes with a fixed set of categories, grouped per table
# multi_value_attributes: multi-valued attributes (separated by "||" or commas), grouped per table

# Each key appears once. Attributes shared by several tables (name, age,
# nationality, file) are listed under the first table that has them; the old
# literal repeated those keys, which a dict literal silently collapses.
attr_desc_dict = {
    # City table
    "city_name": "", "state_name": "", "population": "", "area": "", "gdp": "", "file": "",

    # Owner table (also has file)
    "name": "", "age": "", "nationality": "", "NBA_team": "", "own_year": "",

    # Team table (also has file)
    "team_name": "", "founded_year": "", "location": "", "ownership": "", "championship": "",

    # Player table (also has name, nationality, age and file)
    "birth_date": "", "team": "", "position": "",
    "draft_pick": "", "draft_year": "", "college": "", "nba_championships": "", "mvp_awards": "",
    "olympic_gold_medals": "", "fiba_world_cup": "",
}


def _table_columns(table_name):
    """Return the column names of a loaded table, or [] if it is missing or empty."""
    df = dataframes.get(table_name)
    return list(df.columns) if df is not None and not df.empty else []


# Per-table attributes, taken from the actual CSV columns
city_attributes = _table_columns("city")
owner_attributes = _table_columns("owner")
team_attributes = _table_columns("team")
player_attributes = _table_columns("player")

table_attributes = {
    "city": {attr: attr_desc_dict.get(attr, "") for attr in city_attributes},
    "owner": {attr: attr_desc_dict.get(attr, "") for attr in owner_attributes},
    "team": {attr: attr_desc_dict.get(attr, "") for attr in team_attributes},
    "player": {attr: attr_desc_dict.get(attr, "") for attr in player_attributes},
}

print("各表属性:")
for table_name, attrs in table_attributes.items():
    print(f"{table_name}: {list(attrs.keys())}")

non_numerical_attr_list = [
    "city_name", "state_name", "name", "nationality", "NBA_team", "team_name",
    "location", "ownership", "birth_date", "position", "college", "file"
]

numerical_attr_list = [
    "population", "area", "gdp", "age", "own_year", "founded_year", "championship",
    "draft_pick", "draft_year", "nba_championships", "mvp_awards",
    "olympic_gold_medals", "fiba_world_cup"
]

multi_value_attributes_list = [
    # This dataset has no obvious multi-valued attribute; keep the list empty.
]

category_attr_list = [
    "state_name", "nationality", "team_name", "location", "position", "college"
]

# (table1, table2): (join column of table1, join column of table2).
# Kept identical to join_relationships: owner links to team through NBA_team
# and team links to city through location -> city_name. The previous values
# ("team" on owner, "city" on both team and city) name no listed column.
nba_join_relationships = {
    ("owner", "team"): ("NBA_team", "team_name"),
    ("team", "city"): ("location", "city_name"),
    ("player", "team"): ("team", "team_name"),
}

# Group the attribute types per table, based on the actual column names
non_numerical_attr = {}
numerical_attr = {}
multi_value_attributes = {}
category_attr = {}

for table_name, attrs in table_attributes.items():
    non_numerical_attr[table_name] = [attr for attr in attrs if attr in non_numerical_attr_list]
    numerical_attr[table_name] = [attr for attr in attrs if attr in numerical_attr_list]
    multi_value_attributes[table_name] = [attr for attr in attrs if attr in multi_value_attributes_list]
    category_attr[table_name] = [attr for attr in attrs if attr in category_attr_list]

print("\n属性分类完成:")
for table_name in table_attributes.keys():
    print(f"{table_name}:")
    print(f"  非数值属性: {non_numerical_attr[table_name]}")
    print(f"  数值属性: {numerical_attr[table_name]}")
    print(f"  类别属性: {category_attr[table_name]}")

#### 生成统计信息
- 属性 | 属性值 | 选择率 | 基数
- 用于后续构造 Filter
- 输出CSV格式，类似于原始统计表

In [None]:
def generate_nba_statistics(dataframes):
    """Build the wide per-attribute statistics table (value | count | selectivity).

    For every attribute of every non-empty table, three columns are appended
    side by side: ``<table>.<attribute>`` (each distinct value, followed by a
    ``"(null)"`` row), ``Count`` and ``Selectivity``. ``pd.concat`` pads
    columns of different lengths with NaN.

    Selectivity is ``count / number of rows in the table`` for every row.
    The value rows and the ``"(null)"`` row of one attribute therefore share a
    denominator; value rows used to be normalised over non-null rows only.
    Multi-valued attributes are split on ``"||"`` (or, failing that, ``","``),
    and each part is counted separately.
    """
    column_blocks = []

    for table_name, df in tqdm(dataframes.items(), desc="生成统计信息"):
        if df.empty:
            continue
        total_rows = len(df)

        for column in tqdm(table_attributes[table_name].keys(), desc=f"处理{table_name}属性", leave=False):
            if column not in df.columns:
                continue

            if column in multi_value_attributes_list:
                # Count the individual parts of multi-valued cells.
                parts = []
                for cell in df[column].dropna():
                    text = str(cell)
                    if '||' in text:
                        parts.extend(v.strip() for v in text.split('||') if v.strip())
                    elif ',' in text:
                        parts.extend(v.strip() for v in text.split(',') if v.strip())
                    else:
                        parts.append(text.strip())
                value_counts = pd.Series(parts).value_counts()
            else:
                value_counts = df[column].value_counts()

            selectivities = (value_counts / total_rows).round(3)
            null_count = df[column].isnull().sum()

            column_blocks.append(pd.DataFrame({
                f"{table_name}.{column}": list(value_counts.index) + ["(null)"],
                'Count': list(value_counts.values) + [null_count],
                'Selectivity': list(selectivities.values) + [round(null_count / total_rows, 3)],
            }))

    # One concat at the end avoids re-copying the growing frame once per attribute.
    if not column_blocks:
        return pd.DataFrame()
    return pd.concat(column_blocks, axis=1)


nba_statistics = generate_nba_statistics(dataframes)
nba_statistics.to_csv(statistics_output_dir, index=False, encoding='utf-8')
print(f"统计信息已保存到: {statistics_output_dir}")

#### 定义查询构造参数

In [None]:
#### Query-construction parameters

max_filters = 3  # NOTE(review): the WHERE-combination step in this notebook builds up to 5 filters and does not read this value
min_rows = 2  # The NBA data is relatively small, so the minimum number of matching rows is lowered
max_select = 4  # presumably the maximum number of SELECT columns; not read in the visible cells
limit_list = [1, 2, 3, 5, 8, 10, 15]  # presumably candidate LIMIT values for TOP-K queries; not read in the visible cells

# Number of queries to generate per query type
sample_sf = 10         # SELECT FROM (single table)
sample_sfj = 10        # SELECT FROM JOIN
sample_sfw = 10        # SELECT FROM WHERE (single table)
sample_sfwj = 100      # SELECT FROM WHERE JOIN
sample_sfwt = 15       # SELECT FROM WHERE TOP-K (single table)
sample_sfwtj = 20      # SELECT FROM WHERE TOP-K JOIN
sample_sfwg = 15       # SELECT FROM WHERE GROUP BY (single table)
sample_sfwgj = 20      # SELECT FROM WHERE GROUP BY JOIN
sample_sfwa = 15       # SELECT FROM WHERE AGGREGATION (single table)
sample_sfwaj = 20      # SELECT FROM WHERE AGGREGATION JOIN
sample_sfwga = 20      # SELECT FROM WHERE GROUP BY AGGREGATION (single table)
sample_sfwgaj = 25     # SELECT FROM WHERE GROUP BY AGGREGATION JOIN
sample_sfwgat = 15     # SELECT FROM WHERE GROUP BY AGGREGATION TOP-K (single table)
sample_sfwgatj = 20    # SELECT FROM WHERE GROUP BY AGGREGATION TOP-K JOIN

#### 定义Filter执行方法

In [None]:
#### 定义Filter执行方法

def non_numerical_equal_to(value, condition):
    """Compare two values as strings, ignoring case and surrounding whitespace.

    Both operands are converted with ``str`` first, so ``5`` equals ``"5"``.
    Returns False instead of raising when either value cannot be stringified.
    """
    try:
        return str(value).lower().strip() == str(condition).lower().strip()
    except Exception:
        # A broken __str__ should not abort query construction, but unlike the
        # old bare ``except:`` this no longer swallows KeyboardInterrupt/SystemExit.
        return False

def non_numerical_equal_to_with_split(cell, condition):
    """Return True if any part of a possibly multi-valued cell equals ``condition``.

    The cell is split on ``"||"`` if present, otherwise on ``","`` if present,
    otherwise it is treated as one value. Every part is compared the same way
    as a single value: as strings, ignoring case and surrounding whitespace.
    The split branches used to compare case-sensitively and without
    normalising ``condition``. Missing cells (None/NaN) never match.
    """
    if pd.isna(cell):
        return False

    target = str(condition).lower().strip()
    text = str(cell)
    if '||' in text:
        parts = text.split('||')
    elif ',' in text:
        parts = text.split(',')
    else:
        parts = [text]
    return any(part.lower().strip() == target for part in parts)

def _compare_numbers(value, condition, compare):
    """Apply ``compare`` to ``value`` and ``condition`` after converting both to float.

    A missing ``value`` (anything ``pd.isna`` reports as missing) is treated as
    0. A value or condition that cannot be converted to float makes the
    predicate False instead of raising.
    """
    if pd.isna(value):
        value = 0
    try:
        return compare(float(value), float(condition))
    except (TypeError, ValueError, OverflowError):
        return False


def number_greater_than(value, condition):
    """Return True if ``value > condition`` numerically (a missing value counts as 0)."""
    return _compare_numbers(value, condition, operator.gt)


def number_less_than(value, condition):
    """Return True if ``value < condition`` numerically (a missing value counts as 0)."""
    return _compare_numbers(value, condition, operator.lt)


def number_equal_to(value, condition):
    """Return True if ``value == condition`` numerically (a missing value counts as 0)."""
    return _compare_numbers(value, condition, operator.eq)


def number_greater_equal(value, condition):
    """Return True if ``value >= condition`` numerically (a missing value counts as 0)."""
    return _compare_numbers(value, condition, operator.ge)


def number_less_equal(value, condition):
    """Return True if ``value <= condition`` numerically (a missing value counts as 0)."""
    return _compare_numbers(value, condition, operator.le)

#### 均匀采样函数

In [None]:
#### 均匀采样函数

def balanced_sample(filters, sample_size=60, random_seed=None):
    """Sample ``sample_size`` filters, spread as evenly as possible across tables.

    Each item of ``filters`` has its table name at index 0. If ``filters``
    already has at most ``sample_size`` items, it is returned unchanged.

    Otherwise the quota is split evenly across tables in first-seen order, with
    earlier tables absorbing the remainder. Quota that a small table cannot
    fill is handed to tables that still have unsampled filters, so exactly
    ``sample_size`` items are returned. Previously that quota was lost. When
    no table runs short, the draws are identical to the previous behaviour.

    Note: ``random_seed`` reseeds the *global* ``random`` module. Code that
    calls ``random`` after this function therefore gets a reproducible stream.
    """
    if random_seed is not None:
        random.seed(random_seed)

    if len(filters) <= sample_size:
        return filters

    table_to_filters = defaultdict(list)
    for filter_item in filters:
        table_to_filters[filter_item[0]].append(filter_item)

    tables = list(table_to_filters.keys())
    quotas = {table: 0 for table in tables}
    remaining = sample_size
    active = tables

    # Water-filling: share what is left among tables that still have capacity.
    # Every round either hands out the whole remainder or drops a full table.
    while remaining > 0 and active:
        share, extra = divmod(remaining, len(active))
        still_open = []
        for i, table in enumerate(active):
            capacity = len(table_to_filters[table]) - quotas[table]
            take = min(share + (1 if i < extra else 0), capacity)
            quotas[table] += take
            remaining -= take
            if quotas[table] < len(table_to_filters[table]):
                still_open.append(table)
        active = still_open

    sampled_filters = []
    for table in tables:
        sampled_filters.extend(random.sample(table_to_filters[table], quotas[table]))
    return sampled_filters

#### 汇总可取属性值

In [None]:
#### 汇总可取属性值

# Attributes whose candidate values and quantile thresholds are truncated to int.
_INTEGER_ATTRS = (
    "population", "area", "age", "own_year", "founded_year",
    "championship", "draft_pick", "draft_year", "nba_championships",
    "mvp_awards", "olympic_gold_medals", "fiba_world_cup",
)


def _numeric_value_candidates(df, table_name, attr, non_null_values, total_rows, min_rows, max_rows):
    """Return filter candidates for a numerical attribute.

    Returns:
        None when no value converts to a number (the attribute is skipped);
        [] when the column should be handled as text instead; otherwise a list
        of candidate dicts: exact values, quantile range predicates, or a
        single default ``>=`` predicate.
    """
    numeric_values = pd.to_numeric(non_null_values, errors='coerce').dropna()
    if len(numeric_values) == 0:
        print(f"    {table_name}.{attr}: 无法转换为数值，跳过")
        return None

    original_count = len(non_null_values)
    numeric_count = len(numeric_values)
    if numeric_count < original_count * 0.8:
        # More than 20% of the non-null values are not numeric: treat the column as text.
        print(f"    {table_name}.{attr}: 数据质量差({numeric_count}/{original_count}可转换)，当作文本处理")
        return []

    is_integer = attr in _INTEGER_ATTRS
    column_numeric = pd.to_numeric(df[attr], errors='coerce')
    clean_numeric = column_numeric.dropna()
    candidates = []

    # 1. Exact values whose row count lies within [min_rows, max_rows].
    for value, count in column_numeric.value_counts().items():
        if pd.isna(value) or not (min_rows <= count <= max_rows):
            continue
        if is_integer:
            try:
                value = int(float(value))
            except (TypeError, ValueError, OverflowError):
                continue
        candidates.append({'value': value, 'count': count, 'selectivity': count / total_rows})

    # 2. Range predicates (>= and <=) at the 20/40/60/80% quantiles.
    for q in (0.2, 0.4, 0.6, 0.8):
        threshold = numeric_values.quantile(q)
        if is_integer:
            threshold = int(threshold)
        for op, count in (('>=', (clean_numeric >= threshold).sum()),
                          ('<=', (clean_numeric <= threshold).sum())):
            if min_rows <= count <= max_rows:
                candidates.append({
                    'value': threshold,
                    'operator': op,
                    'count': count,
                    'selectivity': count / total_rows,
                })

    # 3. Fallback: the first default ">=" threshold that still selects min_rows rows.
    if not candidates:
        defaults = (0, 1, 5, 10) if is_integer else (0.0, 1.0, 5.0, 10.0)
        for val in defaults:
            ge_count = (clean_numeric >= val).sum()
            if ge_count >= min_rows:
                candidates.append({
                    'value': val,
                    'operator': '>=',
                    'count': ge_count,
                    'selectivity': ge_count / total_rows,
                })
                break

    return candidates


def _text_value_candidates(df, attr, non_null_values, total_rows, min_rows, max_rows):
    """Return exact-match candidates whose row count lies within [min_rows, max_rows].

    Multi-valued attributes are split on "||" (or, failing that, ","), and each
    part is counted separately.
    """
    if attr in multi_value_attributes_list:
        parts = []
        for cell in non_null_values:
            text = str(cell)
            if '||' in text:
                parts.extend(v.strip() for v in text.split('||') if v.strip())
            elif ',' in text:
                parts.extend(v.strip() for v in text.split(',') if v.strip())
            else:
                parts.append(text.strip())
        counts = Counter(parts).items()
    else:
        counts = df[attr].value_counts().items()

    return [
        {'value': value, 'count': count, 'selectivity': count / total_rows}
        for value, count in counts
        if pd.notna(value) and min_rows <= count <= max_rows
    ]


def _pick_diverse_values(values_list, limit=12):
    """Keep about ten candidates spread evenly over the selectivity range (at most ``limit``)."""
    ordered = sorted(values_list, key=lambda item: item['selectivity'])
    step = max(1, len(ordered) // 10)
    return ordered[::step][:limit]


def get_sample_values_by_selectivity(dataframes, min_rows=2, max_rows=None):
    """Collect candidate filter values per table and attribute, chosen by selectivity.

    Args:
        dataframes: ``{table_name: DataFrame}``; empty frames are skipped.
        min_rows: minimum number of rows a candidate must select.
        max_rows: maximum number of rows a candidate may select. ``None`` means
            half of the rows of *each* table, resolved per table.

    Returns:
        ``{table_name: {attr: [candidate, ...]}}``. Each candidate is a dict
        with ``value``, ``count`` and ``selectivity`` (count / table rows).
        Range predicates also carry ``operator`` (``">="`` or ``"<="``).
        At most 12 candidates are kept per attribute.
    """
    attr_value_dict = {}

    for table_name, df in tqdm(dataframes.items(), desc="获取属性值"):
        if df.empty:
            continue

        attr_value_dict[table_name] = {}
        total_rows = len(df)
        # Resolve the default for this table only. Assigning to max_rows would
        # leak the first table's half-size into every later table.
        table_max_rows = total_rows // 2 if max_rows is None else max_rows

        for attr in tqdm(table_attributes[table_name].keys(), desc=f"处理{table_name}属性值", leave=False):
            if attr not in df.columns:
                continue
            non_null_values = df[attr].dropna()
            if len(non_null_values) == 0:
                continue

            values_list = []
            if attr in numerical_attr_list:
                try:
                    numeric_candidates = _numeric_value_candidates(
                        df, table_name, attr, non_null_values, total_rows, min_rows, table_max_rows)
                except Exception as e:
                    print(f"处理数值属性 {table_name}.{attr} 时出错: {e}")
                    print(f"    数据样本: {non_null_values.head().tolist()}")
                    numeric_candidates = []
                if numeric_candidates is None:
                    # No value converts to a number: skip the attribute entirely.
                    continue
                # Numeric candidates are kept. A stray `else` used to reset this
                # list whenever it was non-empty, discarding every range predicate.
                values_list = numeric_candidates

            # Text attributes, and numerical ones that fell back to text handling.
            if not values_list:
                values_list = _text_value_candidates(
                    df, attr, non_null_values, total_rows, min_rows, table_max_rows)

            if values_list:
                selected = _pick_diverse_values(values_list)
                attr_value_dict[table_name][attr] = selected
                print(f"  {table_name}.{attr}: 选择了 {len(selected)} 个值，selectivity范围 "
                      f"{min(v['selectivity'] for v in selected):.3f}-{max(v['selectivity'] for v in selected):.3f}")

    return attr_value_dict

# Collect candidate filter values with the selectivity-based sampler
attr_value_dict = get_sample_values_by_selectivity(dataframes, min_rows=min_rows)

# Preview the first two attributes of each table and up to three values of each
print("\n属性值示例:")
for table_name, table_values in attr_value_dict.items():
    print(f"\n{table_name}表:")
    for attr, values in list(table_values.items())[:2]:
        if not values:
            print(f"  {attr}: 无有效值")
            continue

        for i, v in enumerate(values[:3]):
            if not isinstance(v, dict):
                print(f"    {attr}[{i}]: {v}")
                continue
            op_info = f" {v['operator']}" if 'operator' in v else " =="
            print(f"    {attr}[{i}]: {v['value']}{op_info} -> {v['count']}行 (选择率:{v['selectivity']:.3f})")

        hidden = len(values) - 3
        if hidden > 0:
            print(f"    ... 还有 {hidden} 个值")

#### 枚举所有可能的Filter

In [None]:
# Fixed condition-string generation for build_filter_dict

def build_filter_dict_fixed(dataframes, attr_value_dict):
    """Map every candidate value to the row indices it selects.

    Returns ``{table: {attr: {condition_str: [row_index, ...]}}}``.
    ``condition_str`` is ``">=v"``, ``"<=v"``, ``">v"`` or ``"<v"`` for range
    predicates, ``"==v"`` for numerical attributes and ``"=='v'"`` for text
    attributes. Conditions that select fewer than ``min_rows`` rows are
    dropped, as are attributes left with no condition.
    """
    range_checks = {
        '>=': lambda column, bound: column >= bound,
        '<=': lambda column, bound: column <= bound,
        '>': lambda column, bound: column > bound,
        '<': lambda column, bound: column < bound,
    }

    filter_dict = {}
    for table_name, df in tqdm(dataframes.items(), desc="处理表"):
        if df.empty:
            continue

        table_filters = filter_dict.setdefault(table_name, {})
        for attr in tqdm(table_attributes[table_name].keys(), desc=f"处理{table_name}属性", leave=False):
            if attr not in df.columns or attr not in attr_value_dict[table_name]:
                continue

            conditions = {}
            for value_info in attr_value_dict[table_name][attr]:
                try:
                    if isinstance(value_info, dict) and 'value' in value_info:
                        target = value_info['value']
                        op = value_info.get('operator', '==')
                    else:
                        target, op = value_info, '=='

                    check = range_checks.get(op)
                    if check is not None:
                        key = f"{op}{target}"
                        mask = check(pd.to_numeric(df[attr], errors='coerce'), target)
                    elif attr in numerical_attr_list:
                        key = f"=={target}"
                        mask = pd.to_numeric(df[attr], errors='coerce') == target
                    else:
                        key = f"=='{target}'"  # text values are quoted
                        mask = df[attr] == target

                    hits = df[mask.fillna(False)].index.tolist()
                    if len(hits) >= min_rows:
                        conditions[key] = hits
                except Exception as e:
                    print(f"处理条件 {table_name}.{attr} = {value_info} 时出错: {e}")

            if conditions:
                table_filters[attr] = conditions

    return filter_dict

# Fix WHERE-expression generation for the filter combinations:
# the final WHERE expression is built directly while collecting the filters.

def _condition_to_where_expr(attr, cond_str):
    """Turn a ``build_filter_dict_fixed`` condition string into a WHERE predicate.

    ``"=='O'Neal'"`` -> ``"name = 'O''Neal'"``: single quotes are doubled,
    SQL-style. ``"==5"``, ``">=5"``, ``"<=2.5"``, ``">5"`` and ``"<5"`` become
    numeric predicates. Returns None for blank text values, unparsable
    numbers or unknown operators.
    """
    # Two-character operators are tried first so ">=" is not read as ">".
    for op in ('==', '>=', '<=', '>', '<'):
        if op in cond_str:
            break
    else:
        return None

    # Split only once, so text values that contain the operator stay intact.
    value = cond_str.split(op, 1)[1]

    if op == '==' and value.startswith("'") and value.endswith("'"):
        text = value[1:-1].strip()
        if not text:
            return None
        escaped = text.replace("'", "''")
        return f"{attr} = '{escaped}'"

    try:
        number = float(value) if '.' in value else int(value)
    except ValueError:
        return None
    sql_op = '=' if op == '==' else op
    return f"{attr} {sql_op} {number}"


def _random_combinations(items, n, limit, materialize_limit):
    """Yield up to ``limit`` distinct ``n``-combinations of ``items`` in random order.

    Small combination spaces are built in full and shuffled, exactly as before.
    Large ones (C(80, 5) is about 24 million) are sampled lazily with rejection
    of repeats, which avoids building a multi-gigabyte list only to use a few
    hundred entries. Lazy sampling is used only when the space holds more than
    twice ``limit`` combinations, so rejection stays cheap.
    """
    total = math.comb(len(items), n)
    if total <= max(materialize_limit, 2 * limit):
        combos = list(itertools.combinations(items, n))
        random.shuffle(combos)
        yield from combos[:max(limit, 0)]
        return

    seen = set()
    pool = range(len(items))
    while len(seen) < limit:
        picked = tuple(sorted(random.sample(pool, n)))
        if picked in seen:
            continue
        seen.add(picked)
        yield tuple(items[i] for i in picked)


def build_where_combinations_simple_fix(filter_dict, max_combinations=1200, materialize_limit=100_000):
    """Combine single-attribute filters into WHERE clauses of 1 to 5 predicates.

    Args:
        filter_dict: ``{table: {attr: {condition_str: [row_index, ...]}}}``
            as produced by ``build_filter_dict_fixed``.
        max_combinations: budget of combinations examined, split 20/30/30/10/10%
            across combination sizes 1..5.
        materialize_limit: combination spaces up to this size are enumerated
            and shuffled in full; larger ones are sampled lazily.

    Returns:
        A list of dicts with keys "WHERE Indices", "WHERE Total Rows", "WHERE",
        "Tables" (sorted, for reproducible output), "Combination", "Operator"
        and "Filter Count".
    """
    all_filters = []
    for table_name, table_filters in filter_dict.items():
        for attr, conditions in table_filters.items():
            for cond_str, indices in conditions.items():
                if len(indices) < min_rows:
                    continue
                where_expr = _condition_to_where_expr(attr, cond_str)
                if where_expr is not None:
                    all_filters.append((table_name, attr, where_expr, set(indices)))

    valid_where = []
    seen_expressions = set()

    # balanced_sample(random_seed=42) seeds the global ``random`` module, which
    # also makes the combination shuffles below reproducible.
    sampled_filters = balanced_sample(all_filters, sample_size=80, random_seed=42)

    # Tenths of max_combinations examined for each combination size.
    tenths_per_size = {1: 2, 2: 3, 3: 3, 4: 1, 5: 1}

    for n in tqdm(range(1, 6), desc="生成Filter组合"):
        max_combos_this_round = int(max_combinations * tenths_per_size[n] / 10)
        # Combinations of 3+ predicates may select fewer rows.
        min_rows_required = min_rows if n <= 2 else max(1, min_rows // 2)
        # A single filter gives the same clause for AND and OR; evaluate it once.
        operators = ('AND', 'OR') if n > 1 else ('AND',)

        combos = _random_combinations(sampled_filters, n, max_combos_this_round, materialize_limit)
        for combo in tqdm(combos, desc=f"{n}个Filter组合", leave=False):
            for op in operators:
                if n == 1:
                    where_clause = combo[0][2]
                    result_indices = combo[0][3]
                    tables_involved = [combo[0][0]]
                else:
                    where_clause = f" {op} ".join(item[2] for item in combo)
                    index_sets = [item[3] for item in combo]
                    combo_tables = [item[0] for item in combo]

                    if op == 'OR':
                        result_indices = set.union(*index_sets)
                    else:
                        result_indices = set.intersection(*index_sets)
                        # NOTE(review): row indices of different tables index
                        # different DataFrames, so intersecting them (or falling
                        # back to the largest set) does not model a JOIN result.
                        if len(set(combo_tables)) > 1 and len(result_indices) < min_rows:
                            result_indices = max(index_sets, key=len)

                    tables_involved = sorted(set(combo_tables))

                if len(result_indices) < min_rows_required or where_clause in seen_expressions:
                    continue
                seen_expressions.add(where_clause)

                valid_where.append({
                    "WHERE Indices": list(result_indices),
                    "WHERE Total Rows": len(result_indices),
                    "WHERE": where_clause,
                    "Tables": tables_involved,
                    "Combination": [[item[0], item[1], item[2]] for item in combo],
                    "Operator": op if n > 1 else "NONE",
                    "Filter Count": n,
                })

    return valid_where

def filter_empty_values(attr_value_dict):
    """
    Drop blank and missing values from a per-table attribute-value mapping.

    ``attr_value_dict`` maps table name -> attribute name -> list of values.
    Each list entry is either a dict carrying the actual value under the
    ``'value'`` key, or the raw value itself.  An entry is dropped when its
    value is ``None``, a float NaN, or a string that is empty after
    ``str.strip()``.  Attributes left with no values are omitted; every
    table key is kept (possibly mapping to an empty dict).

    Returns a new dict; the input mapping and its lists are not modified.
    """
    def _is_blank(value):
        # None / NaN typically come from missing cells in the loaded tables.
        if value is None:
            return True
        if isinstance(value, float) and value != value:  # NaN is the only float != itself
            return True
        # strip() removes all surrounding whitespace, so one emptiness test
        # covers '', ' ', '\t', ... (the old extra "== ' '" test was dead code).
        return isinstance(value, str) and not value.strip()

    filtered_dict = {}

    for table_name, table_attrs in attr_value_dict.items():
        kept_attrs = {}

        for attr, values in table_attrs.items():
            kept_values = []
            for value_info in values:
                # Entries are either {'value': ..., ...} records or bare values.
                if isinstance(value_info, dict) and 'value' in value_info:
                    actual_value = value_info['value']
                else:
                    actual_value = value_info
                if not _is_blank(actual_value):
                    kept_values.append(value_info)

            if kept_values:
                kept_attrs[attr] = kept_values

        filtered_dict[table_name] = kept_attrs

    return filtered_dict

# Apply the fixes: drop blank values, rebuild the filter dict, then regenerate
# the WHERE conditions.  Each step rebinds the notebook-level variable it produces.
print("过滤空值...")
attr_value_dict = filter_empty_values(attr_value_dict)

print("重新构建filter_dict...")
# build_filter_dict_fixed is defined elsewhere in the notebook.
filter_dict = build_filter_dict_fixed(dataframes, attr_value_dict)

print("生成WHERE条件...")
# max_combinations=2000 is passed straight through; its exact effect depends on
# build_where_combinations_simple_fix (defined elsewhere) -- presumably a cap on
# how many combinations are considered; verify there.
valid_where = build_where_combinations_simple_fix(filter_dict, max_combinations=2000)

# Print the first three generated WHERE clauses and their matching row counts
# as a quick sanity check.
print(f"\n修复后的WHERE条件示例:")
for i, where_dict in enumerate(valid_where[:3]):
    print(f"{i+1}. {where_dict['WHERE']}")
    print(f"   -> {where_dict['WHERE Total Rows']}行\n")

#### 构建WHERE条件组合

#### 定义SCHEMA创建函数

In [None]:
#### 定义SCHEMA创建函数

def create_schema(attr_desc_dict, query_attr_list):
    """
    Build a single-table schema for the given attribute names.

    Every name in ``query_attr_list`` maps to ``[sql_type, description]``.
    ``sql_type`` is ``"FLOAT"`` only when the name is in the global
    ``numerical_attr_list`` and not in the global ``non_numerical_attr_list``;
    everything else is ``"VARCHAR(255)"``.  The description comes from
    ``attr_desc_dict`` and defaults to an empty string.  Repeated names
    collapse into a single key.
    """
    def _sql_type(name):
        # Membership in the non-numerical list wins over the numerical list.
        if name not in non_numerical_attr_list and name in numerical_attr_list:
            return "FLOAT"
        return "VARCHAR(255)"

    return {
        name: [_sql_type(name), attr_desc_dict.get(name, "")]
        for name in query_attr_list
    }

def create_multi_table_schema(tables, selected_attrs):
    """
    Build a schema for columns drawn from several joined tables.

    Each ``(table, attr)`` pair in ``selected_attrs`` becomes the key
    ``"table.attr"`` mapped to ``[sql_type, description]``.  ``sql_type`` is
    ``"FLOAT"`` when ``attr`` is in the global ``numerical_attr_list`` and
    ``"VARCHAR(255)"`` otherwise.  ``tables`` is accepted for interface
    compatibility but is not consulted.
    """
    result = {}
    for table_name, attr_name in selected_attrs:
        sql_type = "FLOAT" if attr_name in numerical_attr_list else "VARCHAR(255)"
        result[f"{table_name}.{attr_name}"] = [sql_type, f"{table_name} table {attr_name} field"]
    return result

# 通用的WHERE条件选择函数
def select_where_by_ratio(valid_where, total_needed, single_table_only=True, multi_table_ratio=0.3):
    """
    Pick up to ``total_needed`` WHERE-condition dicts, mixing single-filter
    and multi-filter conditions in a fixed proportion.

    Args:
        valid_where: list of WHERE dicts.  Each has a ``"Tables"`` list and
            may carry ``"Filter Count"`` (treated as 1 when absent).
        total_needed: how many conditions to return; values <= 0 yield [].
        single_table_only: when True, only conditions touching exactly one
            table are considered.
        multi_table_ratio: fraction of ``total_needed`` reserved for
            multi-filter (Filter Count > 1) conditions, clamped to [0, 1].
            Despite its name it splits by filter count, not by table count.
            The default 0.3 gives ~70% single-filter and ~30% multi-filter
            conditions.

    Returns:
        A new list: the chosen multi-filter conditions first, then the chosen
        single-filter conditions, each taken in input order.  If one pool is
        too small to fill its share, the shortfall is taken from the other
        pool.  The result is shorter than ``total_needed`` only when both
        pools together are too small.
    """
    if total_needed <= 0:
        return []

    if single_table_only:
        single_filter = [w for w in valid_where if w.get("Filter Count", 1) == 1 and len(w["Tables"]) == 1]
        multi_filter = [w for w in valid_where if w.get("Filter Count", 1) > 1 and len(w["Tables"]) == 1]
    else:
        single_filter = [w for w in valid_where if w.get("Filter Count", 1) == 1]
        multi_filter = [w for w in valid_where if w.get("Filter Count", 1) > 1]

    ratio = min(max(multi_table_ratio, 0.0), 1.0)
    # round(..., 9) strips float noise before truncating.  For example,
    # 10 * (1 - 0.8) == 1.9999999999999996 would otherwise truncate to 1, not 2.
    single_quota = int(round(total_needed * (1 - ratio), 9))
    multi_quota = total_needed - single_quota

    single_take = min(single_quota, len(single_filter))
    multi_take = min(multi_quota, len(multi_filter))

    # Hand any quota a short pool could not use over to the other pool.
    spare = total_needed - single_take - multi_take
    if spare > 0:
        extra_multi = min(spare, len(multi_filter) - multi_take)
        multi_take += extra_multi
        spare -= extra_multi
        single_take += min(spare, len(single_filter) - single_take)

    return multi_filter[:multi_take] + single_filter[:single_take]

def extract_attrs_from_where(where_clause):
    """
    Return the attribute names that precede a comparison operator in a
    WHERE clause string.

    A name counts when it is directly followed (optionally after spaces) by
    one of ``>``, ``<``, ``=`` or ``!``.  Quoted literals are blanked out
    first, so operator characters inside values (e.g. ``team = 'LA=Lakers'``)
    are not mistaken for extra attributes.  This covers ``'...'`` with
    SQL-style ``''`` escapes and ``"..."``.  A table qualifier is dropped:
    ``team.city = ...`` yields ``city``.  Names are returned in order of
    appearance and may repeat.  An empty or None clause yields [].
    """
    if not where_clause:
        return []
    without_literals = re.sub(r"'(?:[^']|'')*'|\"[^\"]*\"", " ", where_clause)
    return re.findall(r'(\w+)\s*[><=!]', without_literals)

#### 构建SELECT FROM（单表）

In [None]:
#### 构建SELECT FROM（单表）

def generate_select_from_queries(dataframes):
    """
    Generate single-table SELECT ... FROM queries with no WHERE clause ("SF" type).

    ``sample_sf`` queries are spread as evenly as possible over the tables
    that are non-empty and have at least one attribute in
    ``table_attributes``.  The first ``sample_sf % k`` such tables get one
    extra query.  Each query selects between 1 and ``max_select`` random
    attributes of its table.  Because there is no filter, "WHERE Indices"
    lists every row position ``0 .. len(table) - 1``.

    Relies on the notebook globals ``sample_sf``, ``max_select``,
    ``table_attributes`` and ``attr_desc_dict``.
    """
    queries = []

    if not dataframes:
        print("警告：没有加载到任何表数据")
        return queries

    # Split the quota only over tables that can actually yield queries.
    # Dividing by all tables and then skipping empty ones silently lost
    # their share of sample_sf.
    eligible_tables = [
        name for name, df in dataframes.items()
        if not df.empty and table_attributes[name]
    ]
    if not eligible_tables:
        return queries

    queries_per_table, remainder = divmod(sample_sf, len(eligible_tables))

    for i, table_name in enumerate(tqdm(eligible_tables, desc="生成SF查询")):
        table_query_count = queries_per_table + (1 if i < remainder else 0)

        # Per-table invariants, hoisted out of the per-query loop.
        available_attrs = list(table_attributes[table_name].keys())
        total_rows = len(dataframes[table_name])
        max_attrs = min(max_select, len(available_attrs))

        for _ in range(table_query_count):
            num_attrs = random.randint(1, max_attrs)
            selected_attrs = random.sample(available_attrs, num_attrs)

            queries.append({
                "Type": "SF",
                "SELECT": selected_attrs,
                "FROM": [table_name],
                "WHERE Indices": list(range(total_rows)),  # no filter: every row position
                "WHERE Total Rows": total_rows,
                "WHERE": "",  # no WHERE condition
                "Tables": [table_name],
                "Combination": [],  # no filter combination
                "Operator": "NONE",
                "Filter Count": 0,
                "SCHEMA": create_schema(attr_desc_dict, selected_attrs),
            })

    return queries

sf_queries = generate_select_from_queries(dataframes)
print(f"生成SF查询: {len(sf_queries)} 个")

#### 构建SELECT FROM JOIN查询（多表）

In [None]:
#### 构建SELECT FROM JOIN查询（多表）

def generate_select_from_join_queries(dataframes):
    """
    Generate multi-table SELECT ... FROM ... JOIN queries with no WHERE clause ("SFJ" type).

    ``sample_sfj`` queries are spread as evenly as possible over the join
    paths below whose tables are all loaded and non-empty.  Each query
    selects a random set of columns drawn from the first 4 attributes of
    every joined table.  It takes at least 2 columns when that many are
    available, and never more than ``max_select``.

    "WHERE Total Rows" is only an estimate: the row count of the smallest
    joined table.  "WHERE Indices" is left empty because the real join
    result is not computed here.

    Relies on the notebook globals ``sample_sfj``, ``max_select`` and
    ``table_attributes``.
    """
    queries = []

    # Candidate join paths; generate_sql_query turns them into JOIN ... ON clauses.
    join_combinations = [
        ['team', 'city'],           # team-city
        ['owner', 'team'],          # owner-team
        ['player', 'team'],         # player-team
        ['team', 'city', 'owner'],  # team-city-owner
        ['team', 'player', 'owner'] # team-player-owner
    ]

    # Keep only join paths whose tables are all loaded and non-empty.
    valid_join_combinations = [
        combo for combo in join_combinations
        if all(table in dataframes and not dataframes[table].empty for table in combo)
    ]

    if not valid_join_combinations:
        print("警告：没有有效的JOIN组合")
        return queries

    queries_per_combo, remainder = divmod(sample_sfj, len(valid_join_combinations))

    for i, join_combo in enumerate(tqdm(valid_join_combinations, desc="生成SFJ查询")):
        combo_query_count = queries_per_combo + (1 if i < remainder else 0)
        if combo_query_count == 0:
            continue

        # Candidate columns: at most the first 4 attributes of each table,
        # which keeps the SELECT list short.
        available_attrs = [
            (table, attr)
            for table in join_combo
            for attr in list(table_attributes[table].keys())[:4]
        ]
        if not available_attrs:
            continue

        # Aim for at least 2 columns, but never demand more than exist or than
        # max_select allows.  randint(2, 1) would raise ValueError.
        max_attrs = min(max_select, len(available_attrs))
        min_attrs = min(2, max_attrs)

        # Rough estimate only: the row count of the smallest participating table.
        estimated_rows = min(len(dataframes[table]) for table in join_combo)

        for _ in range(combo_query_count):
            num_attrs = random.randint(min_attrs, max_attrs)
            selected_attrs = random.sample(available_attrs, num_attrs)
            select_clause = [f"{table}.{attr}" for table, attr in selected_attrs]

            queries.append({
                "Type": "SFJ",
                "SELECT": select_clause,
                "FROM": join_combo,
                "JOIN_TYPE": "INNER JOIN",
                "WHERE Indices": [],  # the actual joined rows are not computed here
                "WHERE Total Rows": estimated_rows,  # estimate, see docstring
                "WHERE": "",  # no WHERE condition
                "Tables": join_combo,
                "Combination": [],  # no filter combination
                "Operator": "NONE",
                "Filter Count": 0,
                "SCHEMA": create_multi_table_schema(join_combo, selected_attrs),
            })

    return queries

sfj_queries = generate_select_from_join_queries(dataframes)
print(f"生成SFJ查询: {len(sfj_queries)} 个")

#### 构建SELECT FROM WHERE查询（单表）

In [None]:
#### 构建SELECT FROM WHERE查询（单表）

def generate_select_from_where_queries(valid_where, dataframes):
    """
    Build single-table SELECT ... FROM ... WHERE queries ("SFW" type).

    WHERE conditions come from ``select_where_by_ratio`` with ``sample_sfw``
    single-table entries at a 0.7 multi-filter ratio.  For every condition
    whose table is loaded and non-empty, up to ``max_select`` of the table's
    first 8 attributes are picked at random.  The schema covers those columns
    plus every attribute referenced in the WHERE clause.

    Relies on the notebook globals ``sample_sfw``, ``max_select``,
    ``table_attributes`` and ``attr_desc_dict``.
    """
    picked = select_where_by_ratio(valid_where, sample_sfw, single_table_only=True, multi_table_ratio=0.7)
    results = []

    for cond in tqdm(picked, desc="生成SFW查询"):
        table = cond["Tables"][0]
        if table not in dataframes or dataframes[table].empty:
            continue

        candidates = list(table_attributes[table].keys())[:8]
        if not candidates:
            continue
        chosen = random.sample(candidates, min(max_select, len(candidates)))

        # Attributes are taken from the WHERE text itself, not the Combination entry.
        schema_attrs = list(chosen)
        for attr in extract_attrs_from_where(cond.get("WHERE", "")):
            if attr not in schema_attrs:
                schema_attrs.append(attr)

        query = dict(cond)
        query["Type"] = "SFW"
        query["SELECT"] = chosen
        query["FROM"] = [table]
        query["SCHEMA"] = create_schema(attr_desc_dict, schema_attrs)
        results.append(query)

    return results

sfw_queries = generate_select_from_where_queries(valid_where, dataframes)
print(f"生成SFW查询: {len(sfw_queries)} 个")

#### 构建多表JOIN查询

In [None]:
#### 构建多表JOIN查询

def generate_join_queries(valid_where, dataframes):
    """
    Build multi-table SELECT ... FROM ... JOIN ... WHERE queries ("SFWJ" type).

    Two sources of WHERE conditions are used, each sized ``sample_sfwj // 2``:

    1. Conditions picked with ``single_table_only=False``.  A condition that
       already spans several tables is joined over exactly those tables.  A
       single-table condition is joined over the first supported join path
       that contains its table.
    2. Up to 8 single-table conditions.  Each is attached to at most two
       supported join paths that contain the condition's table.

    A join path is used only when every table in it is loaded and non-empty.
    The WHERE condition's tables must all appear in FROM; otherwise the WHERE
    clause would reference columns the query cannot see.  The old code
    hard-wired ['team', 'city'] and ignored this.

    Relies on the notebook globals ``sample_sfwj``, ``max_select`` and
    ``table_attributes``.
    """
    queries = []

    # Supported two-table join paths (JOIN keys are resolved in generate_sql_query).
    join_combinations = [['team', 'city'], ['owner', 'team'], ['player', 'team']]

    def _is_loaded(tables):
        # True when every table is present in `dataframes` and non-empty.
        return all(table in dataframes and not dataframes[table].empty for table in tables)

    valid_join_combinations = [combo for combo in join_combinations if _is_loaded(combo)]

    def _combos_covering(where_tables):
        # Join paths whose FROM list contains every table the WHERE clause touches.
        needed = set(where_tables)
        return [combo for combo in valid_join_combinations if needed <= set(combo)]

    def _random_select(tables):
        # Up to max_select (table, attr) pairs from the first 3 attributes of each table.
        available_attrs = [
            (table, attr)
            for table in tables
            for attr in list(table_attributes[table].keys())[:3]
        ]
        if not available_attrs:
            return None
        return random.sample(available_attrs, min(max_select, len(available_attrs)))

    def _build(where_dict, tables, selected_attrs):
        # Copy the WHERE dict and add the SELECT/FROM/JOIN/SCHEMA fields.
        query_dict = where_dict.copy()
        query_dict.update({
            "Type": "SFWJ",
            "SELECT": [f"{table}.{attr}" for table, attr in selected_attrs],
            "FROM": tables,
            "JOIN_TYPE": "INNER JOIN",
            "SCHEMA": create_multi_table_schema(tables, selected_attrs),
        })
        return query_dict

    # Pick WHERE conditions by ratio.
    multi_table_where = select_where_by_ratio(valid_where, sample_sfwj//2, single_table_only=False, multi_table_ratio=0.6)
    single_table_for_join = select_where_by_ratio(valid_where, sample_sfwj//2, single_table_only=True, multi_table_ratio=0.8)

    for where_dict in tqdm(multi_table_where, desc="生成多表JOIN查询"):
        where_tables = where_dict["Tables"]
        if len(where_tables) > 1:
            # NOTE(review): multi-table WHERE tables are used as FROM as-is;
            # generate_sql_query must know a join condition for them.
            tables_involved = where_tables
            if not _is_loaded(tables_involved):
                continue
        else:
            candidates = _combos_covering(where_tables)
            if not candidates:
                continue
            tables_involved = candidates[0]

        selected_attrs = _random_select(tables_involved)
        if selected_attrs is None:
            continue
        queries.append(_build(where_dict, tables_involved, selected_attrs))

    # Apply single-table WHERE conditions to join queries.
    for where_dict in tqdm(single_table_for_join[:8], desc="生成单表->JOIN查询", leave=False):
        for join_combo in _combos_covering(where_dict["Tables"])[:2]:
            selected_attrs = _random_select(join_combo)
            if selected_attrs is None:
                continue
            queries.append(_build(where_dict, join_combo, selected_attrs))

    return queries

join_queries = generate_join_queries(valid_where, dataframes)
print(f"生成JOIN查询: {len(join_queries)} 个")

#### 构建SELECT FROM WHERE TOP-K查询

In [None]:
#### 构建SELECT FROM WHERE TOP-K查询

def generate_topk_queries(valid_where, dataframes):
    """
    Build top-k (ORDER BY ... LIMIT) queries over one table ("SFWT") and
    over two-table joins ("SFWTJ").

    Single-table part: uses ``sample_sfwt`` single-table WHERE conditions.
    Each qualifying condition's table must be loaded and have an attribute in
    ``numerical_attr_list``.  For such a condition, the query:

    - selects up to ``max_select`` of the table's first 8 attributes;
    - orders by a random numeric attribute in a random direction;
    - uses a LIMIT drawn from ``limit_list``.

    Join part: uses ``sample_sfwtj // 2`` mixed conditions plus up to 8
    single-table conditions.  Each is attached to at most two supported join
    paths, and only to paths that contain every table the condition
    references.  Without that check the WHERE clause would refer to a table
    missing from FROM.

    Relies on the notebook globals ``sample_sfwt``, ``sample_sfwtj``,
    ``max_select``, ``limit_list``, ``numerical_attr_list``,
    ``table_attributes`` and ``attr_desc_dict``.
    """
    queries = []
    order_options = ['ASC', 'DESC']

    single_table_where = select_where_by_ratio(valid_where, sample_sfwt, single_table_only=True, multi_table_ratio=0.6)

    for where_dict in single_table_where:
        table_name = where_dict["Tables"][0]

        if table_name not in dataframes or dataframes[table_name].empty:
            continue

        available_attrs = list(table_attributes[table_name].keys())
        numerical_attrs = [attr for attr in available_attrs if attr in numerical_attr_list]
        if not numerical_attrs:
            continue

        selected_attrs = random.sample(available_attrs[:8], min(max_select, len(available_attrs[:8])))
        order_column = random.choice(numerical_attrs)
        order_type = random.choice(order_options)
        limit_value = random.choice(limit_list)

        # Schema covers selected columns, the ORDER BY column and WHERE attributes.
        query_attr_list = selected_attrs + [order_column]
        for attr in extract_attrs_from_where(where_dict.get("WHERE", "")):
            if attr not in query_attr_list:
                query_attr_list.append(attr)

        query_dict = where_dict.copy()
        query_dict.update({
            "Type": "SFWT",
            "SELECT": selected_attrs,
            "FROM": [table_name],
            "ORDER BY": [order_column, order_type],
            "LIMIT": limit_value,
            "SCHEMA": create_schema(attr_desc_dict, query_attr_list)
        })
        queries.append(query_dict)

    # JOIN TOP-K queries
    multi_table_where = select_where_by_ratio(valid_where, sample_sfwtj//2, single_table_only=False, multi_table_ratio=0.7)
    single_table_for_join = select_where_by_ratio(valid_where, sample_sfwtj//2, single_table_only=True, multi_table_ratio=0.8)

    join_combinations = [['team', 'city'], ['owner', 'team'], ['player', 'team']]

    # Keep only join paths whose tables are all loaded and non-empty.
    valid_join_combinations = [
        combo for combo in join_combinations
        if all(table in dataframes and not dataframes[table].empty for table in combo)
    ]

    for where_dict in multi_table_where + single_table_for_join[:8]:
        # Only join paths that include every table the WHERE clause refers to.
        needed_tables = set(where_dict["Tables"])
        covering_combos = [combo for combo in valid_join_combinations if needed_tables <= set(combo)]

        for join_combo in covering_combos[:2]:
            available_attrs = []
            numerical_join_attrs = []

            for table in join_combo:
                available_attrs.extend((table, attr) for attr in list(table_attributes[table].keys())[:3])
                numerical_join_attrs.extend(
                    (table, attr) for attr in table_attributes[table].keys() if attr in numerical_attr_list
                )

            if not (numerical_join_attrs and available_attrs):
                continue

            selected_attrs = random.sample(available_attrs, min(max_select, len(available_attrs)))
            select_clause = [f"{table}.{attr}" for table, attr in selected_attrs]

            order_table, order_attr = random.choice(numerical_join_attrs)
            order_column = f"{order_table}.{order_attr}"
            order_type = random.choice(order_options)
            limit_value = random.choice(limit_list)

            query_dict = where_dict.copy()
            query_dict.update({
                "Type": "SFWTJ",
                "SELECT": select_clause,
                "FROM": join_combo,
                "JOIN_TYPE": "INNER JOIN",
                "ORDER BY": [order_column, order_type],
                "LIMIT": limit_value,
                "SCHEMA": create_multi_table_schema(join_combo, selected_attrs)
            })
            queries.append(query_dict)

    return queries

topk_queries = generate_topk_queries(valid_where, dataframes)
print(f"生成TOP-K查询: {len(topk_queries)} 个")

#### 构建SELECT FROM WHERE GROUP BY查询

In [None]:
#### 构建SELECT FROM WHERE GROUP BY查询

def generate_groupby_queries(valid_where, dataframes):
    """
    Build GROUP BY queries over one table ("SFWG") and over two-table joins ("SFWGJ").

    Single-table part: uses ``sample_sfwg`` single-table WHERE conditions.
    For each condition whose table has category attributes
    (``category_attr``), the query groups by one random category attribute.
    It selects up to ``max_select`` of the table's first 8 attributes and
    guarantees the GROUP BY column is among them.

    Join part: uses ``sample_sfwgj // 2`` mixed conditions plus up to 8
    single-table conditions.  Each is attached to every supported join path
    that contains all tables the condition references.  Such a query groups
    by a random category attribute of the joined tables.  It selects that
    column once, together with up to ``max_select - 1`` other columns drawn
    from the first 3 attributes of each table.

    Relies on the notebook globals ``sample_sfwg``, ``sample_sfwgj``,
    ``max_select``, ``category_attr``, ``table_attributes`` and
    ``attr_desc_dict``.
    """
    queries = []

    # Pick WHERE conditions by ratio.
    single_table_where = select_where_by_ratio(valid_where, sample_sfwg, single_table_only=True, multi_table_ratio=0.7)

    for where_dict in single_table_where:
        table_name = where_dict["Tables"][0]

        if table_name not in dataframes or dataframes[table_name].empty:
            continue

        category_attrs = category_attr.get(table_name, [])
        if not category_attrs:
            continue

        groupby_columns = random.sample(category_attrs, 1)
        candidate_attrs = list(table_attributes[table_name].keys())[:8]
        selected_attrs = random.sample(candidate_attrs, min(max_select, len(candidate_attrs)))

        # Make sure the GROUP BY column is selected while staying within max_select.
        if groupby_columns[0] not in selected_attrs:
            selected_attrs = groupby_columns + selected_attrs[:max_select-1]

        query_attr_list = selected_attrs.copy()
        for attr in extract_attrs_from_where(where_dict.get("WHERE", "")):
            if attr not in query_attr_list:
                query_attr_list.append(attr)

        query_dict = where_dict.copy()
        query_dict.update({
            "Type": "SFWG",
            "SELECT": selected_attrs,
            "FROM": [table_name],
            "GROUP BY": groupby_columns,
            "SCHEMA": create_schema(attr_desc_dict, query_attr_list)
        })
        queries.append(query_dict)

    # JOIN GROUP BY queries
    multi_table_where = select_where_by_ratio(valid_where, sample_sfwgj//2, single_table_only=False, multi_table_ratio=0.6)
    single_table_for_join = select_where_by_ratio(valid_where, sample_sfwgj//2, single_table_only=True, multi_table_ratio=0.8)

    join_combinations = [['team', 'city'], ['owner', 'team']]

    # Keep only join paths whose tables are all loaded and non-empty.
    valid_join_combinations = [
        combo for combo in join_combinations
        if all(table in dataframes and not dataframes[table].empty for table in combo)
    ]

    for where_dict in multi_table_where + single_table_for_join[:8]:
        needed_tables = set(where_dict["Tables"])

        for join_combo in valid_join_combinations:
            # Skip join paths missing a table the WHERE clause refers to;
            # the resulting query would reference columns outside FROM.
            if not needed_tables <= set(join_combo):
                continue

            join_category_attrs = [
                (table, attr) for table in join_combo for attr in category_attr.get(table, [])
            ]
            if not join_category_attrs:
                continue

            available_attrs = [
                (table, attr)
                for table in join_combo
                for attr in list(table_attributes[table].keys())[:3]
            ]
            if not available_attrs:
                continue

            groupby_table, groupby_attr = random.choice(join_category_attrs)
            groupby_column = f"{groupby_table}.{groupby_attr}"

            # Exclude the GROUP BY column from the random pick.  It is appended
            # below, so sampling it as well would list it twice in SELECT.
            other_attrs = [pair for pair in available_attrs if pair != (groupby_table, groupby_attr)]
            selected_attrs = random.sample(other_attrs, min(max(max_select - 1, 0), len(other_attrs)))
            selected_attrs.append((groupby_table, groupby_attr))
            select_clause = [f"{table}.{attr}" for table, attr in selected_attrs]

            query_dict = where_dict.copy()
            query_dict.update({
                "Type": "SFWGJ",
                "SELECT": select_clause,
                "FROM": join_combo,
                "JOIN_TYPE": "INNER JOIN",
                "GROUP BY": [groupby_column],
                "SCHEMA": create_multi_table_schema(join_combo, selected_attrs)
            })
            queries.append(query_dict)

    return queries

groupby_queries = generate_groupby_queries(valid_where, dataframes)
print(f"生成GROUP BY查询: {len(groupby_queries)} 个")

#### 构建SELECT FROM WHERE AGGREGATION查询

In [None]:
#### 构建SELECT FROM WHERE AGGREGATION查询

def generate_aggregation_queries(valid_where, dataframes):
    """
    Build aggregate queries over one table ("SFWA") and over two-table joins ("SFWAJ").

    When the table(s) have numeric attributes (``numerical_attr``), a random
    one is aggregated with MAX/MIN/AVG/SUM.  Otherwise the query falls back
    to COUNT(*).

    The single-table part uses ``sample_sfwa`` single-table WHERE conditions.
    The join part uses ``sample_sfwaj // 2`` mixed conditions plus up to 8
    single-table conditions.  Each is attached to every supported join path
    that contains all tables the condition references.

    Relies on the notebook globals ``sample_sfwa``, ``sample_sfwaj``,
    ``numerical_attr`` and ``attr_desc_dict``.
    """
    queries = []
    aggregation_functions = ['COUNT', 'MAX', 'MIN', 'AVG', 'SUM']

    # Pick WHERE conditions by ratio.
    single_table_where = select_where_by_ratio(valid_where, sample_sfwa, single_table_only=True, multi_table_ratio=0.6)

    for where_dict in single_table_where:
        table_name = where_dict["Tables"][0]

        if table_name not in dataframes or dataframes[table_name].empty:
            continue

        numerical_attrs = numerical_attr.get(table_name, [])

        if numerical_attrs:
            agg_column = random.choice(numerical_attrs)
            function = random.choice(aggregation_functions[1:])  # COUNT is reserved for '*'
            numerical = True
        else:
            agg_column = "*"
            function = 'COUNT'
            numerical = False

        query_attr_list = [agg_column] if agg_column != "*" else []
        for attr in extract_attrs_from_where(where_dict.get("WHERE", "")):
            if attr not in query_attr_list:
                query_attr_list.append(attr)

        query_dict = where_dict.copy()
        query_dict.update({
            "Type": "SFWA",
            "SELECT": [f"{function}({agg_column})"],
            "FROM": [table_name],
            "AGGREGATION": [agg_column],
            "AGGREGATION Function": function,
            "Numerical": numerical,
            "SCHEMA": create_schema(attr_desc_dict, query_attr_list)
        })
        queries.append(query_dict)

    # JOIN AGGREGATION queries
    multi_table_where = select_where_by_ratio(valid_where, sample_sfwaj//2, single_table_only=False, multi_table_ratio=0.7)
    single_table_for_join = select_where_by_ratio(valid_where, sample_sfwaj//2, single_table_only=True, multi_table_ratio=0.8)

    join_combinations = [['team', 'city'], ['owner', 'team']]

    # Keep only join paths whose tables are all loaded and non-empty.
    valid_join_combinations = [
        combo for combo in join_combinations
        if all(table in dataframes and not dataframes[table].empty for table in combo)
    ]

    for where_dict in multi_table_where + single_table_for_join[:8]:
        needed_tables = set(where_dict["Tables"])

        for join_combo in valid_join_combinations:
            # Skip join paths missing a table the WHERE clause refers to;
            # the resulting query would reference columns outside FROM.
            if not needed_tables <= set(join_combo):
                continue

            join_numerical_attrs = [
                (table, attr) for table in join_combo for attr in numerical_attr.get(table, [])
            ]

            if join_numerical_attrs:
                agg_table, agg_attr = random.choice(join_numerical_attrs)
                agg_column = f"{agg_table}.{agg_attr}"
                function = random.choice(aggregation_functions[1:])
                numerical = True
                schema_attrs = [(agg_table, agg_attr)]
            else:
                agg_column = "*"
                function = 'COUNT'
                numerical = False
                schema_attrs = []

            query_dict = where_dict.copy()
            query_dict.update({
                "Type": "SFWAJ",
                "SELECT": [f"{function}({agg_column})"],
                "FROM": join_combo,
                "JOIN_TYPE": "INNER JOIN",
                "AGGREGATION": [agg_column],
                "AGGREGATION Function": function,
                "Numerical": numerical,
                "SCHEMA": create_multi_table_schema(join_combo, schema_attrs)
            })
            queries.append(query_dict)

    return queries

aggregation_queries = generate_aggregation_queries(valid_where, dataframes)
print(f"生成AGGREGATION查询: {len(aggregation_queries)} 个")

#### 构建SELECT FROM WHERE GROUP BY AGGREGATION查询

In [None]:

def generate_groupby_aggregation_queries(valid_where, dataframes):
    """
    Build GROUP BY + aggregate queries over one table ("SFWGA") and over
    two-table joins ("SFWGAJ").

    Each query groups by one random category attribute (``category_attr``).
    When a numeric attribute (``numerical_attr``) is available, it aggregates
    that attribute with MAX/MIN/AVG/SUM; otherwise it uses COUNT(*).

    The single-table part uses ``sample_sfwga`` single-table WHERE conditions.
    The join part uses ``sample_sfwgaj // 2`` mixed conditions plus up to 8
    single-table conditions.  Each is attached to every supported join path
    that contains all tables the condition references.

    Relies on the notebook globals ``sample_sfwga``, ``sample_sfwgaj``,
    ``category_attr``, ``numerical_attr`` and ``attr_desc_dict``.
    """
    queries = []
    aggregation_functions = ['COUNT', 'MAX', 'MIN', 'AVG', 'SUM']

    # Pick WHERE conditions by ratio.
    single_table_where = select_where_by_ratio(valid_where, sample_sfwga, single_table_only=True, multi_table_ratio=0.7)

    for where_dict in single_table_where:
        table_name = where_dict["Tables"][0]

        if table_name not in dataframes or dataframes[table_name].empty:
            continue

        category_attrs = category_attr.get(table_name, [])
        numerical_attrs = numerical_attr.get(table_name, [])
        if not category_attrs:
            continue

        groupby_columns = random.sample(category_attrs, 1)

        if numerical_attrs:
            agg_column = random.choice(numerical_attrs)
            function = random.choice(aggregation_functions[1:])  # COUNT is reserved for '*'
            numerical = True
        else:
            agg_column = "*"
            function = 'COUNT'
            numerical = False

        select_clause = groupby_columns + [f"{function}({agg_column})"]

        query_attr_list = groupby_columns + ([agg_column] if agg_column != "*" else [])
        for attr in extract_attrs_from_where(where_dict.get("WHERE", "")):
            if attr not in query_attr_list:
                query_attr_list.append(attr)

        query_dict = where_dict.copy()
        query_dict.update({
            "Type": "SFWGA",
            "SELECT": select_clause,
            "FROM": [table_name],
            "GROUP BY": groupby_columns,
            "AGGREGATION": [agg_column],
            "AGGREGATION Function": function,
            "Numerical": numerical,
            "SCHEMA": create_schema(attr_desc_dict, query_attr_list)
        })
        queries.append(query_dict)

    # JOIN GROUP BY AGGREGATION queries
    multi_table_where = select_where_by_ratio(valid_where, sample_sfwgaj//2, single_table_only=False, multi_table_ratio=0.6)
    single_table_for_join = select_where_by_ratio(valid_where, sample_sfwgaj//2, single_table_only=True, multi_table_ratio=0.8)

    join_combinations = [['team', 'city'], ['owner', 'team']]

    # Keep only join paths whose tables are all loaded and non-empty.
    valid_join_combinations = [
        combo for combo in join_combinations
        if all(table in dataframes and not dataframes[table].empty for table in combo)
    ]

    for where_dict in multi_table_where + single_table_for_join[:8]:
        needed_tables = set(where_dict["Tables"])

        for join_combo in valid_join_combinations:
            # Skip join paths missing a table the WHERE clause refers to;
            # the resulting query would reference columns outside FROM.
            if not needed_tables <= set(join_combo):
                continue

            join_category_attrs = [
                (table, attr) for table in join_combo for attr in category_attr.get(table, [])
            ]
            join_numerical_attrs = [
                (table, attr) for table in join_combo for attr in numerical_attr.get(table, [])
            ]
            if not join_category_attrs:
                continue

            groupby_table, groupby_attr = random.choice(join_category_attrs)
            groupby_column = f"{groupby_table}.{groupby_attr}"

            if join_numerical_attrs:
                agg_table, agg_attr = random.choice(join_numerical_attrs)
                agg_column = f"{agg_table}.{agg_attr}"
                function = random.choice(aggregation_functions[1:])
                numerical = True
                selected_attrs = [(groupby_table, groupby_attr), (agg_table, agg_attr)]
            else:
                agg_column = "*"
                function = 'COUNT'
                numerical = False
                selected_attrs = [(groupby_table, groupby_attr)]

            select_clause = [groupby_column, f"{function}({agg_column})"]

            query_dict = where_dict.copy()
            query_dict.update({
                "Type": "SFWGAJ",
                "SELECT": select_clause,
                "FROM": join_combo,
                "JOIN_TYPE": "INNER JOIN",
                "GROUP BY": [groupby_column],
                "AGGREGATION": [agg_column],
                "AGGREGATION Function": function,
                "Numerical": numerical,
                "SCHEMA": create_multi_table_schema(join_combo, selected_attrs)
            })
            queries.append(query_dict)

    return queries

groupby_aggregation_queries = generate_groupby_aggregation_queries(valid_where, dataframes)
print(f"生成GROUP BY AGGREGATION查询: {len(groupby_aggregation_queries)} 个")

#### 构建SELECT FROM WHERE GROUP BY AGGREGATION TOP-K查询

In [None]:
def generate_groupby_aggregation_topk_queries(valid_where, dataframes):
    """
    Build single-table GROUP BY + aggregate + top-k queries ("SFWGAT" type).

    Uses ``sample_sfwgat`` single-table WHERE conditions.  For each condition
    whose table is loaded and has category attributes (``category_attr``),
    the query:

    - groups by one random category attribute;
    - if the table has numeric attributes (``numerical_attr``), flips a coin
      between aggregating a random numeric attribute with MAX/MIN/AVG/SUM
      and plain COUNT(*); otherwise always uses COUNT(*);
    - orders by the aggregate expression in a random direction, with a LIMIT
      drawn from ``limit_list``.

    Relies on the notebook globals ``sample_sfwgat``, ``limit_list``,
    ``category_attr``, ``numerical_attr`` and ``attr_desc_dict``.
    """
    agg_funcs = ['COUNT', 'MAX', 'MIN', 'AVG', 'SUM']
    directions = ['ASC', 'DESC']
    results = []

    picked = select_where_by_ratio(valid_where, sample_sfwgat, single_table_only=True, multi_table_ratio=0.7)

    for cond in picked:
        table = cond["Tables"][0]
        if table not in dataframes or dataframes[table].empty:
            continue

        categories = category_attr.get(table, [])
        numerics = numerical_attr.get(table, [])
        if not categories:
            continue

        group_cols = random.sample(categories, 1)

        # The coin flip only happens when numeric columns exist (short-circuit).
        use_numeric = bool(numerics) and random.random() > 0.5
        if use_numeric:
            target = random.choice(numerics)
            func = random.choice(agg_funcs[1:])
        else:
            target = "*"
            func = 'COUNT'
        agg_expr = f"{func}({target})"

        direction = random.choice(directions)
        limit = random.choice(limit_list)

        schema_attrs = group_cols + ([target] if target != "*" else [])
        for attr in extract_attrs_from_where(cond.get("WHERE", "")):
            if attr not in schema_attrs:
                schema_attrs.append(attr)

        results.append({
            **cond,
            "Type": "SFWGAT",
            "SELECT": group_cols + [agg_expr],
            "FROM": [table],
            "GROUP BY": group_cols,
            "AGGREGATION": [target],
            "AGGREGATION Function": func,
            "Numerical": use_numeric,
            "ORDER BY": [agg_expr, direction],
            "LIMIT": limit,
            "SCHEMA": create_schema(attr_desc_dict, schema_attrs),
        })

    return results

groupby_aggregation_topk_queries = generate_groupby_aggregation_topk_queries(valid_where, dataframes)
print(f"生成GROUP BY AGGREGATION TOP-K查询: {len(groupby_aggregation_topk_queries)} 个")

#### SQL生成函数

In [None]:
# 完整的SQL生成函数定义

# Foreign-key join predicates between the NBA tables, keyed by the unordered pair of
# table names.
# NOTE(review): an earlier comment warned that the team table's city column may be named
# ``location`` in the real data -- confirm against team.csv.
_NBA_JOIN_CONDITIONS = {
    frozenset(("owner", "team")): "owner.team = team.team_name",
    frozenset(("team", "city")): "team.city = city.city",
    frozenset(("player", "team")): "player.team = team.team_name",
}

# Pairs with no direct foreign key whose ``team`` columns both reference
# team.team_name, so they can be joined on that shared column.
_NBA_SHARED_TEAM_PAIRS = {frozenset(("player", "owner"))}


def _nba_join_condition(current_table, joined_tables):
    """Return the predicate joining *current_table* to an already-joined table, or None.

    Foreign-key relationships are preferred, checked against the most recently
    joined table first. The shared-``team`` shortcut is used only when no
    foreign key applies.
    """
    for prev_table in reversed(joined_tables):
        condition = _NBA_JOIN_CONDITIONS.get(frozenset((prev_table, current_table)))
        if condition:
            return condition
    for prev_table in reversed(joined_tables):
        if frozenset((prev_table, current_table)) in _NBA_SHARED_TEAM_PAIRS:
            return f"{prev_table}.team = {current_table}.team"
    return None


def generate_sql_query(query_dict):
    """Render a query dict as a SQL SELECT statement.

    Keys read from *query_dict* (all optional):
        FROM:      list of table names; the first is the FROM table and the rest are
                   attached with INNER JOIN. A missing or empty list falls back to
                   ``["nba_table"]``.
        SELECT:    list of select expressions (default ``["*"]``).
        WHERE:     predicate string.
        GROUP BY:  list of columns.
        ORDER BY:  ``[column, direction]`` pair.
        LIMIT:     row limit (omitted when falsy).

    Each joined table is linked to any earlier table it has a known relationship
    with. If no relationship is known, a warning is printed and a fallback predicate
    ``<first_table>.team_name = <table>.team`` is emitted unchanged.

    Returns:
        The SQL text, with clauses separated by newlines and terminated by ``;``.
    """
    tables_list = query_dict.get("FROM") or ["nba_table"]
    select_fields = query_dict.get("SELECT", ["*"])

    sql_parts = ["SELECT " + ", ".join(select_fields)]

    from_clause = f"FROM {tables_list[0]}"
    for i, current_table in enumerate(tables_list[1:], start=1):
        join_condition = _nba_join_condition(current_table, tables_list[:i])
        if join_condition is None:
            prev_table = tables_list[i - 1]
            print(f"Warning: Using fallback join for {prev_table} -> {current_table}")
            join_condition = f"{tables_list[0]}.team_name = {current_table}.team"
        from_clause += f"\nINNER JOIN {current_table} ON {join_condition}"
    sql_parts.append(from_clause)

    if query_dict.get("WHERE"):
        sql_parts.append(f"WHERE {query_dict['WHERE']}")
    if query_dict.get("GROUP BY"):
        sql_parts.append(f"GROUP BY {', '.join(query_dict['GROUP BY'])}")
    if query_dict.get("ORDER BY"):
        order_col, order_type = query_dict["ORDER BY"]
        sql_parts.append(f"ORDER BY {order_col} {order_type}")
    if query_dict.get("LIMIT"):
        sql_parts.append(f"LIMIT {query_dict['LIMIT']}")

    return "\n".join(sql_parts) + ";"

def generate_schema_sql(schema_dict, table_name="NBA_Data"):
    """Return a CREATE TABLE statement for *table_name* built from *schema_dict*.

    Args:
        schema_dict: mapping ``column -> (data_type, description)``. A falsy
            description omits the COMMENT clause. An empty or None mapping yields a
            placeholder table with a single ``id`` primary-key column.
        table_name: name used in the CREATE TABLE statement.

    Returns:
        The SQL text, terminated by ``;``.
    """
    if not schema_dict:
        return f"CREATE TABLE {table_name} (id INTEGER PRIMARY KEY);"

    columns = []
    for col_name, (data_type, description) in schema_dict.items():
        if description:
            # Double embedded single quotes so text such as "player's age" stays
            # inside the SQL string literal.
            escaped = str(description).replace("'", "''")
            comment = f" COMMENT '{escaped}'"
        else:
            comment = ""
        columns.append(f"    {col_name} {data_type}{comment}")

    return f"CREATE TABLE {table_name} (\n" + ",\n".join(columns) + "\n);"

def _render_query_sql(query, position, query_type):
    """Build the ``{"schema", "query", "complete"}`` SQL bundle for one query dict.

    ``position`` is the 1-based index of the query within its type group, and
    ``query_type`` is that group's label; both appear in the comment header.
    """
    tables = query.get("FROM", ["nba_table"])
    # Multi-table queries get one combined schema under a generic table name.
    table_name = tables[0] if len(tables) == 1 else "NBA_Data"

    schema_sql = generate_schema_sql(query.get("SCHEMA", {}), table_name)
    query_sql = generate_sql_query(query)

    header_lines = [
        f"-- Query {position} ({query_type})",
        f"-- Total Rows: {query.get('WHERE Total Rows', 'N/A')}",
        f"-- SELECT: {len(query.get('SELECT', []))}",
        f"-- FILTER: {len(query.get('Combination', []))}",
    ]
    if len(query.get("FROM", [])) > 1:
        header_lines.append(f"-- TABLES: {', '.join(query['FROM'])}")

    complete_sql = (
        "\n".join(header_lines) + "\n\n"
        + schema_sql + "\n\n"
        + query_sql + "\n"
        + "-" * 50 + "\n\n"
    )
    return {"schema": schema_sql, "query": query_sql, "complete": complete_sql}


def save_nba_queries_to_sql(output_dir="./sql_queries/"):
    """Render every generated query to SQL and write one ``<Type>.sql`` file per type.

    Collects the query lists held in the module-level variables named in
    ``query_vars``; missing or empty ones are skipped. Queries are grouped by their
    ``"Type"`` key (``"Unknown"`` when absent).

    Side effects:
        * creates *output_dir* if needed and overwrites ``<output_dir>/<Type>.sql``;
        * stores the rendered SQL on each query dict under ``query["SQL"]``, with
          keys ``schema``, ``query`` and ``complete``.

    Args:
        output_dir: directory that receives the .sql files.
    """
    os.makedirs(output_dir, exist_ok=True)

    query_vars = [
        'sf_queries', 'sfj_queries', 'sfw_queries', 'join_queries',
        'topk_queries', 'groupby_queries', 'aggregation_queries',
        'groupby_aggregation_queries', 'groupby_aggregation_topk_queries',
    ]
    module_globals = globals()
    all_queries = []
    for var_name in query_vars:
        queries = module_globals.get(var_name)
        if queries:
            all_queries.extend(queries)

    if not all_queries:
        print("没有找到查询数据")
        return

    queries_by_type = defaultdict(list)
    for query in all_queries:
        queries_by_type[query.get('Type', 'Unknown')].append(query)

    for query_type, queries in queries_by_type.items():
        for position, query in enumerate(queries, start=1):
            # Use the group label rather than query['Type'] so untyped queries
            # (grouped as 'Unknown') do not raise KeyError.
            query["SQL"] = _render_query_sql(query, position, query_type)

        sql_file = os.path.join(output_dir, f"{query_type}.sql")
        with open(sql_file, 'w', encoding='utf-8') as f:
            f.write("".join(query["SQL"]["complete"] for query in queries))

        print(f"已生成SQL文件: {sql_file} ({len(queries)} 个查询)")

# Render all collected queries into per-type .sql files under ./sql_queries/.
print("开始生成NBA SQL文件...")
save_nba_queries_to_sql(output_dir="./sql_queries/")
print("SQL文件生成完成！")

#### 保存所有查询结果

#### NBA Join查询


In [None]:
#### NBA Join查询生成器集成代码 - 完整修复版本

import os
import random

# NBA business-scenario templates. Every template involves the `player` table.
# Single-table templates name it via "primary_table"; JOIN templates list all of
# their tables via "tables". "filters" names the attributes that receive WHERE
# predicates.
NBA_TEMPLATES = [
    # --- Single-table templates: player ---
    {
        "name": "nationality_analysis",
        "description": "特定国籍球员分析",
        "filters": ["nationality"],
        "primary_table": "player",
    },
    {
        "name": "age_performance",
        "description": "年龄与表现关系研究",
        "filters": ["age", "nba_championships"],
        "primary_table": "player",
    },
    {
        "name": "draft_analysis",
        "description": "选秀球员分析",
        "filters": ["draft_year", "draft_pick"],
        "primary_table": "player",
    },
    {
        "name": "champion_players",
        "description": "冠军球员研究",
        "filters": ["nba_championships"],
        "primary_table": "player",
    },
    {
        "name": "mvp_study",
        "description": "MVP球员研究",
        "filters": ["mvp_awards"],
        "primary_table": "player",
    },
    # --- Two-table JOIN templates: player + team ---
    {
        "name": "team_performance_players",
        "description": "球队表现与球员关系",
        "filters": ["championship", "age"],
        "tables": ["player", "team"],
    },
    {
        "name": "superstar_teams",
        "description": "巨星球员所在球队",
        "filters": ["mvp_awards", "championship"],
        "tables": ["player", "team"],
    },
    # --- Two-table JOIN template: player + owner ---
    {
        "name": "nationality_ownership",
        "description": "球员国籍与老板关系",
        "filters": ["nationality", "age"],
        "tables": ["player", "owner"],
    },
    # --- Three-table JOIN templates: player + team + owner ---
    {
        "name": "player_team_owner_dynamics",
        "description": "球员-球队-老板动态分析",
        "filters": ["age", "championship", "own_year"],
        "tables": ["player", "team", "owner"],
    },
    {
        "name": "mvp_ownership_success",
        "description": "MVP球员与老板成功关系",
        "filters": ["mvp_awards", "founded_year", "nationality"],
        "tables": ["player", "team", "owner"],
    },
]

# Query-type codes: S=SELECT, F=FROM, W=WHERE, T=top-k (ORDER BY + LIMIT),
# G=GROUP BY, A=aggregation. The JOIN variant of each type appends "J".
NBA_SINGLE_TYPES = ["SF", "SFW", "SFWT", "SFWG", "SFWA", "SFWGA", "SFWGAT"]
NBA_JOIN_TYPES = [single_type + "J" for single_type in NBA_SINGLE_TYPES]

class NBASQLTemplateGenerator:
    """Build annotated SQL text (schema plus query) for the NBA business templates.

    Each generated string contains a ``--`` comment header, CREATE TABLE
    statement(s) and one SELECT query. WHERE predicates are sampled at random from
    ``attr_value_dict`` and are NOT checked against the data, so a query may
    legitimately return zero rows.

    Column typing relies on the module-level ``numerical_attr_list`` and
    ``category_attr_list`` defined elsewhere in this notebook.
    """

    def __init__(self, dataframes, attr_value_dict, filter_dict, join_relationships):
        """
        Args:
            dataframes: mapping of table name -> data. Only key membership is used
                (by ``create_multi_table_schema``).
            attr_value_dict: ``attr_value_dict[table][attr]`` is a list of candidate
                filter values. Each entry is either a raw value or a dict with a
                ``'value'`` key and an optional ``'operator'`` key.
            filter_dict: stored as ``self.filter_dict``; not read by this class.
            join_relationships: accepted for interface compatibility but ignored.
                The NBA join graph defined below is always used.
        """
        self.dataframes = dataframes
        self.attr_value_dict = attr_value_dict
        self.filter_dict = filter_dict

        # (left_table, right_table) -> (left_key, right_key) means
        # left_table.left_key = right_table.right_key.
        # The string "via_team" marks pairs that have no direct key pair and are
        # related only through the team table.
        self.join_relationships = {
            ("owner", "team"): ("team", "team_name"),   # owner.team = team.team_name
            ("team", "city"): ("city", "city"),         # team.city = city.city
            ("player", "team"): ("team", "team_name"),  # player.team = team.team_name
            ("owner", "player"): "via_team",            # owner -> team -> player
            ("player", "city"): "via_team",             # player -> team -> city
        }

        # Not read by any method of this class.
        self.min_rows = 1

        # Known columns per table. Drives SELECT lists, schemas and filter lookup.
        self.table_attr_mapping = {
            "city": ["city", "state_name", "population", "area", "gdp"],
            "owner": ["name", "age", "nationality", "team", "own_year"],
            "team": ["team_name", "founded_year", "city", "ownership", "championship"],
            "player": ["name", "birth_date", "nationality", "age", "team", "position",
                       "draft_pick", "draft_year", "college", "nba_championships",
                       "mvp_awards", "olympic_gold_medals", "fiba_world_cup"],
        }

    # ------------------------------------------------------------------ filters

    def get_template_filter_value(self, table_name, attr, relaxation_level=0):
        """Pick a random candidate value for ``table_name.attr`` and build a predicate.

        Args:
            table_name: table whose candidate values are sampled.
            attr: column name.
            relaxation_level: at 0, numeric columns may use any comparison. At
                higher levels, numeric predicates are restricted to ``>=`` / ``<=``.

        Returns:
            ``(value, predicate)``, or ``(None, None)`` when the column has no
            candidate values or the sampled value is a blank string.
        """
        if (table_name not in self.attr_value_dict
                or attr not in self.attr_value_dict[table_name]):
            return None, None

        values = self.attr_value_dict[table_name][attr]
        if not values:
            return None, None

        selected = random.choice(values)
        if isinstance(selected, dict) and 'value' in selected:
            actual_value = selected['value']
            # A pre-computed operator is used verbatim.
            if 'operator' in selected:
                return actual_value, f"{table_name}.{attr} {selected['operator']} {actual_value}"
        else:
            actual_value = selected

        if isinstance(actual_value, str) and not actual_value.strip():
            return None, None

        if attr in numerical_attr_list:
            # "=" (not "==") keeps the predicate standard SQL, matching the text branch.
            operators = ["=", ">", ">=", "<", "<="] if relaxation_level == 0 else [">=", "<="]
            return actual_value, f"{table_name}.{attr} {random.choice(operators)} {actual_value}"

        # Double embedded single quotes (e.g. "D'Angelo") so the literal stays well-formed.
        escaped = str(actual_value).replace("'", "''")
        return actual_value, f"{table_name}.{attr} = '{escaped}'"

    def apply_template_filters(self, template, max_relaxation=2):
        """Build one WHERE predicate per filter attribute of *template*.

        Tables come from ``template["primary_table"]`` or ``template["tables"]``.
        When neither key exists, they are inferred from ``table_attr_mapping``.
        Each filter uses the first involved table that has that attribute. Up to
        ``max_relaxation + 1`` attempts are made, and the result size is never
        validated.

        Returns:
            ``([0], conditions, relaxation_level)`` on success. Otherwise
            ``(None, None, "NO_VALID_TABLES")`` or
            ``(None, None, "FAILED_ALL_RELAXATION")``.
        """
        if "primary_table" in template:
            tables_involved = [template["primary_table"]]
        elif "tables" in template:
            tables_involved = template["tables"]
        else:
            tables_involved = []
            for filter_attr in template["filters"]:
                for table_name, attrs in self.table_attr_mapping.items():
                    if filter_attr in attrs and table_name not in tables_involved:
                        tables_involved.append(table_name)

        if not tables_involved:
            return None, None, "NO_VALID_TABLES"

        for relaxation_level in range(max_relaxation + 1):
            conditions = self._build_conditions(template["filters"], tables_involved, relaxation_level)
            if conditions:
                # Result size is not checked; report a single placeholder row index.
                return [0], conditions, relaxation_level

        return None, None, "FAILED_ALL_RELAXATION"

    def _build_conditions(self, filters, tables_involved, relaxation_level):
        """Return one predicate per filter, or None if any filter yields no predicate."""
        conditions = []
        for filter_attr in filters:
            condition = None
            for table_name in tables_involved:
                if filter_attr in self.table_attr_mapping.get(table_name, []):
                    _, condition = self.get_template_filter_value(
                        table_name, filter_attr, relaxation_level
                    )
                    if condition:
                        break
            if not condition:
                return None
            conditions.append(condition)
        return conditions

    # ------------------------------------------------------------------ joins

    def generate_join_clause(self, tables):
        """Return the INNER JOIN lines that attach ``tables[1:]`` to ``tables[0]``.

        For each table, a direct key-pair relationship with any earlier table is
        preferred. Otherwise, a "via_team" pair whose tables both carry a ``team``
        column is joined on that shared column. If neither applies, a warning is
        printed and a fallback predicate is emitted.

        Returns:
            "" for fewer than two tables; otherwise "\\n" followed by the
            newline-joined JOIN lines.
        """
        if len(tables) <= 1:
            return ""

        join_parts = []
        for i in range(1, len(tables)):
            current_table = tables[i]
            earlier_tables = tables[:i]
            join_condition = (
                self._direct_join_condition(earlier_tables, current_table)
                or self._shared_team_join_condition(earlier_tables, current_table)
            )
            if join_condition:
                join_parts.append(f"INNER JOIN {current_table} ON {join_condition}")
            else:
                print(f"Warning: Using fallback join for {current_table}")
                join_parts.append(f"INNER JOIN {current_table} ON {tables[0]}.team_name = {current_table}.team")

        return "\n" + "\n".join(join_parts)

    def _direct_join_condition(self, earlier_tables, current_table):
        """Return the first key-pair predicate between an earlier table and *current_table*, or None."""
        for prev_table in earlier_tables:
            for (t1, t2), keys in self.join_relationships.items():
                if not isinstance(keys, tuple):
                    continue  # "via_team" entries carry no key pair to unpack
                key1, key2 = keys
                if prev_table == t1 and current_table == t2:
                    return f"{prev_table}.{key1} = {current_table}.{key2}"
                if prev_table == t2 and current_table == t1:
                    return f"{prev_table}.{key2} = {current_table}.{key1}"
        return None

    def _shared_team_join_condition(self, earlier_tables, current_table):
        """Join a "via_team" pair on its shared ``team`` column when both tables have one.

        For example, ``owner.team`` and ``player.team`` both equal ``team.team_name``
        (per the relationships above), so ``player.team = owner.team`` links them
        without the team table.
        """
        for prev_table in earlier_tables:
            for (t1, t2), keys in self.join_relationships.items():
                if keys != "via_team" or {prev_table, current_table} != {t1, t2}:
                    continue
                if ("team" in self.table_attr_mapping.get(prev_table, [])
                        and "team" in self.table_attr_mapping.get(current_table, [])):
                    return f"{prev_table}.team = {current_table}.team"
        return None

    # ------------------------------------------------------------------ schema

    @staticmethod
    def _column_definition(attr):
        """Return the indented DDL line for *attr*: FLOAT when numeric, else VARCHAR(255)."""
        column_type = "FLOAT" if attr in numerical_attr_list else "VARCHAR(255)"
        return f"    {attr} {column_type}"

    def create_multi_table_schema(self, tables):
        """Return blank-line-separated CREATE TABLE statements for *tables*.

        Only tables present in ``self.dataframes`` are emitted. Each one gets a
        synthetic ``<table>_id INTEGER PRIMARY KEY`` column plus the columns from
        ``table_attr_mapping``.
        """
        schemas = []
        for table in tables:
            if table not in self.dataframes:
                continue
            columns = [f"    {table}_id INTEGER PRIMARY KEY"]
            columns.extend(self._column_definition(attr) for attr in self.table_attr_mapping.get(table, []))
            schemas.append(f"CREATE TABLE {table} (\n" + ",\n".join(columns) + "\n);")
        return "\n\n".join(schemas)

    # ------------------------------------------------------------------ queries

    def generate_query_sql(self, template, query_type, query_id):
        """Return the annotated SQL text for one query of *query_type* from *template*.

        A query is always produced. If no filter predicates can be built, it has no
        WHERE clause and its relaxation is reported as "FAILED". Templates with more
        than one table are rendered as JOIN queries, whether or not the filters
        succeeded.
        """
        _, conditions, relaxation_used = self.apply_template_filters(template)
        if relaxation_used in ("FAILED_ALL_RELAXATION", "NO_VALID_TABLES"):
            conditions = []
            relaxation_used = "FAILED"

        if "primary_table" in template:
            tables = [template["primary_table"]]
        elif "tables" in template:
            tables = template["tables"]
        else:
            tables = ["player"]  # default table

        if len(tables) > 1:
            return self._generate_join_sql(template, query_type, query_id, tables, conditions, relaxation_used)
        return self._generate_single_sql(template, query_type, query_id, tables[0], conditions, relaxation_used)

    @staticmethod
    def _random_aggregate(numeric_cols):
        """Return a random aggregate expression.

        Returns COUNT(*) when COUNT is drawn or no numeric column is available.
        """
        agg_func = random.choice(["COUNT", "MAX", "MIN", "AVG", "SUM"])
        if numeric_cols and agg_func != "COUNT":
            return f"{agg_func}({random.choice(numeric_cols)})"
        return "COUNT(*)"

    def _add_type_clauses(self, sql_parts, base_type, columns):
        """Append the clauses implied by *base_type*, rewriting SELECT (``sql_parts[0]``) if needed.

        ``columns`` is a list of ``(attr, sql_name)`` pairs. ``attr`` decides the
        column kind, and ``sql_name`` is what appears in the SQL text.

        SFWT:   ORDER BY a random numeric column, plus LIMIT.
        SFWG:   GROUP BY a random categorical column, selecting only that column.
        SFWGA:  as SFWG, plus one aggregate in the SELECT.
        SFWGAT: as SFWGA, ordered by that aggregate, plus LIMIT.
        SFWA:   SELECT a single aggregate.
        Other types leave *sql_parts* unchanged.
        """
        def pick(kind_list):
            return [name for attr, name in columns if attr in kind_list]

        if base_type == "SFWT":
            numeric_cols = pick(numerical_attr_list)
            if numeric_cols:
                order_col = random.choice(numeric_cols)
                sql_parts.append(f"ORDER BY {order_col} {random.choice(['ASC', 'DESC'])}")
                sql_parts.append(f"LIMIT {random.choice([3, 5, 8, 10])}")

        elif base_type in ("SFWG", "SFWGA", "SFWGAT"):
            category_cols = pick(category_attr_list)
            if not category_cols:
                return
            group_col = random.choice(category_cols)
            sql_parts.append(f"GROUP BY {group_col}")
            if base_type == "SFWG":
                # Only the grouping column may appear un-aggregated in a GROUP BY query.
                sql_parts[0] = f"SELECT {group_col}"
                return
            agg_expr = self._random_aggregate(pick(numerical_attr_list))
            sql_parts[0] = f"SELECT {group_col}, {agg_expr}"
            if base_type == "SFWGAT":
                # Rank groups by the aggregate that is actually selected.
                sql_parts.append(f"ORDER BY {agg_expr} DESC")
                sql_parts.append(f"LIMIT {random.choice([3, 5, 8])}")

        elif base_type == "SFWA":
            sql_parts[0] = f"SELECT {self._random_aggregate(pick(numerical_attr_list))}"

    @staticmethod
    def _filters_summary(template, conditions, relaxation_used):
        """Return the ``-- Filters: ...`` header line (without a trailing newline)."""
        total = len(template['filters'])
        if not conditions:
            return f"-- Filters: 0/{total} (no WHERE conditions)"
        line = f"-- Filters: {len(conditions)}/{total} (using {len(conditions)} filters)"
        if relaxation_used != "FAILED" and relaxation_used > 0:
            line += f" (relaxed {relaxation_used} levels)"
        return line

    def _generate_join_sql(self, template, query_type, query_id, tables, conditions, relaxation_used):
        """Render a multi-table query. Clause shape follows *query_type* without its trailing 'J'."""
        base_type = query_type.rstrip('J')

        # At most three columns from each of the first two tables, capped at five,
        # to keep the SELECT list short.
        select_attrs = [
            f"{table}.{attr}"
            for table in tables[:2]
            for attr in self.table_attr_mapping.get(table, [])[:3]
        ]
        sql_parts = [
            f"SELECT {', '.join(select_attrs[:5])}",
            f"FROM {tables[0]}" + self.generate_join_clause(tables),
        ]
        if conditions:
            sql_parts.append(f"WHERE {' AND '.join(conditions)}")

        columns = [
            (attr, f"{table}.{attr}")
            for table in tables
            for attr in self.table_attr_mapping.get(table, [])
        ]
        self._add_type_clauses(sql_parts, base_type, columns)

        query_sql = "\n".join(sql_parts) + ";"
        schema_sql = self.create_multi_table_schema(tables)

        header = f"-- Query {query_id} - {query_type} (JOIN Query)\n"
        header += f"-- Template: {template['name']}\n"
        header += f"-- Description: {template['description']}\n"
        header += f"-- Tables: {', '.join(tables)}\n"
        header += self._filters_summary(template, conditions, relaxation_used) + "\n\n"

        return header + schema_sql + "\n\n" + query_sql + "\n\n" + "-" * 60 + "\n\n"

    def _generate_single_sql(self, template, query_type, query_id, table_name, conditions, relaxation_used):
        """Render a single-table query selecting up to four random columns of *table_name*."""
        table_attrs = self.table_attr_mapping.get(table_name, [])
        select_attrs = random.sample(table_attrs, min(4, len(table_attrs)))

        sql_parts = [f"SELECT {', '.join(select_attrs)}", f"FROM {table_name}"]
        if conditions:
            sql_parts.append(f"WHERE {' AND '.join(conditions)}")

        self._add_type_clauses(sql_parts, query_type, [(attr, attr) for attr in table_attrs])

        schema_sql = (
            f"CREATE TABLE {table_name} (\n"
            + ",\n".join(self._column_definition(attr) for attr in table_attrs)
            + "\n);"
        )
        query_sql = "\n".join(sql_parts) + ";"

        header = f"-- Query {query_id} - {query_type} (Single Table)\n"
        header += f"-- Template: {template['name']}\n"
        header += f"-- Description: {template['description']}\n"
        header += f"-- Table: {table_name}\n"
        header += self._filters_summary(template, conditions, relaxation_used) + "\n\n"

        return header + schema_sql + "\n\n" + query_sql + "\n\n" + "-" * 60 + "\n\n"

def _generate_sql_batch(generator, template, qtype, count):
    """Collect up to *count* SQL strings for one (template, type), allowing ``2 * count`` attempts."""
    sql_content = []
    attempts = 0
    while len(sql_content) < count and attempts < count * 2:
        attempts += 1
        # Query ids are 1-based and advance only on success.
        sql = generator.generate_query_sql(template, qtype, len(sql_content) + 1)
        if sql:
            sql_content.append(sql)
    return sql_content


def _write_template_sql_file(template_dir, template, qtype, sql_content):
    """Write the file header plus the generated queries to ``<template_dir>/<qtype>.sql``."""
    filename = os.path.join(template_dir, f"{qtype}.sql")
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(f"-- NBA {template['description']} - {qtype} 查询\n")
        f.write(f"-- 模板: {template['name']}\n")
        f.write(f"-- Filter数量: {len(template['filters'])}\n")
        f.write(f"-- 涉及表: {template.get('tables', [template.get('primary_table', 'auto')])}\n")
        f.write("-- " + "=" * 60 + "\n\n")
        f.write("".join(sql_content))


def generate_nba_template_queries(base_dir="./NBA_Template_Queries/", queries_per_type=3):
    """Generate the template-driven NBA queries and write them under *base_dir*.

    Each template in ``NBA_TEMPLATES`` gets a sub-directory named after it. That
    directory receives one ``<type>.sql`` file per query type: all single-table
    types, plus the JOIN types when the template lists more than one table. A
    ``SUMMARY.txt`` with overall counts is written to *base_dir*.

    Uses the module-level ``dataframes``, ``attr_value_dict``, ``filter_dict`` and
    ``join_relationships`` defined elsewhere in this notebook.

    Args:
        base_dir: output directory (created if missing).
        queries_per_type: number of queries requested per (template, type).

    Returns:
        Total number of queries written.
    """
    os.makedirs(base_dir, exist_ok=True)

    generator = NBASQLTemplateGenerator(dataframes, attr_value_dict, filter_dict, join_relationships)

    total_generated = 0
    total_failed = 0

    for template in NBA_TEMPLATES:
        template_dir = os.path.join(base_dir, template["name"])
        os.makedirs(template_dir, exist_ok=True)

        # JOIN query types only make sense for templates that span several tables.
        is_multi_table = len(template.get("tables", [])) > 1
        query_types = NBA_SINGLE_TYPES + (NBA_JOIN_TYPES if is_multi_table else [])

        for qtype in query_types:
            sql_content = _generate_sql_batch(generator, template, qtype, queries_per_type)
            total_generated += len(sql_content)
            # Requested queries that could not be produced within the attempt budget.
            total_failed += max(queries_per_type - len(sql_content), 0)
            _write_template_sql_file(template_dir, template, qtype, sql_content)

    attempted = total_generated + total_failed
    # Guard against ZeroDivisionError when nothing was requested.
    success_rate = f"{total_generated / attempted * 100:.1f}%" if attempted else "N/A"

    summary = f"""NBA模板查询生成汇总
{"="*50}

总模板数: {len(NBA_TEMPLATES)}
成功生成: {total_generated} 个查询
生成失败: {total_failed} 个查询
成功率: {success_rate} (如果有查询的话)

模板类型分布:
- 单表模板: {len([t for t in NBA_TEMPLATES if 'primary_table' in t])} 个
- 多表模板: {len([t for t in NBA_TEMPLATES if 'tables' in t])} 个

查询类型: 
- 单表查询类型: {len(NBA_SINGLE_TYPES)} 种
- JOIN查询类型: {len(NBA_JOIN_TYPES)} 种

输出目录: {base_dir}
"""

    with open(os.path.join(base_dir, "SUMMARY.txt"), 'w', encoding='utf-8') as f:
        f.write(summary)

    return total_generated

# Run the template-query generation. Any exception is reported (message plus
# full traceback) instead of aborting the notebook cell.
import traceback

try:
    total_queries = generate_nba_template_queries()
    print(f"✅ NBA模板查询生成完成: {total_queries} 个查询")
    print("📁 文件保存在: ./NBA_Template_Queries/")
except Exception as err:
    print(f"❌ 生成错误: {err}")
    traceback.print_exc()