# In [None]:
# XGBoost regression with WGAN-GP data augmentation
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import xgboost as xgb
from sklearn.model_selection import KFold, RandomizedSearchCV, train_test_split
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import matplotlib.pyplot as plt
import os
import joblib
from pathlib import Path 
# Project layout: PROJECT_ROOT is the parent of the current working directory.
CURRENT_DIR = Path.cwd()
PROJECT_ROOT = CURRENT_DIR.parent
DATA_DIR = PROJECT_ROOT / "data"
OUTPUT_DIR = PROJECT_ROOT / "output"

# --- 0. Global configuration ---
# Data paths
BASE_DATA_PATH = DATA_DIR
DEV_SET_FILE = DATA_DIR / "development_set_selected_features.xlsx"
TEST_SET_FILE = DATA_DIR / "final_test_set_selected_features.xlsx"
AUGMENTED_DATA_OUTPUT_FOLDER = DATA_DIR / "augmented_outputs"  # intermediate augmented data
MODEL_OUTPUT_PATH = DATA_DIR / "trained_models"  # trained models
OUTPUT_PLOT_PATH = OUTPUT_DIR  # exported figures

os.makedirs(AUGMENTED_DATA_OUTPUT_FOLDER, exist_ok=True)
os.makedirs(MODEL_OUTPUT_PATH, exist_ok=True)
os.makedirs(OUTPUT_PLOT_PATH, exist_ok=True)  # make sure the figure directory exists

TARGET_COLUMN = 'Rowing distance'
RANDOM_STATE = 42
N_SPLITS_KFOLD = 5  # number of K-Fold splits

# Seed PyTorch as well: RANDOM_STATE is already passed to KFold, RandomizedSearchCV,
# train_test_split and XGBoost, but weight initialisation, DataLoader shuffling and
# latent-noise sampling use torch's global generator, which would otherwise make the
# synthetic data (and everything trained on it) differ from run to run.
torch.manual_seed(RANDOM_STATE)

# WGAN-GP default parameters
# CHO_LEVELS = np.array([0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2])
DEFAULT_WGAN_PARAMS = {
    'latent_dim': 100,
    'lambda_gp': 10,
    'n_critic': 5,
    'lr': 0.00005,
    'batch_size': 32,
    'epochs_for_cv': 500,  # epochs used for per-fold augmentation inside K-Fold
    'epochs_for_final': 3000  # epochs used when augmenting the whole development set (HPO / final model)
}

# Search space for XGBoost RandomizedSearchCV
XGB_PARAM_GRID = {
    'n_estimators': [100, 200, 300, 400],
    'max_depth': [3, 5, 7],
    'learning_rate': [0.01, 0.05, 0.1],
    'subsample': [0.7, 0.8, 0.9, 1],
    'colsample_bytree': [0.7, 0.8, 0.9, 1],
    'gamma': [0, 0.1, 0.2],
    'reg_alpha': [0, 0.01, 0.1],
    'reg_lambda': [0.5, 1, 1.5]
}
N_ITER_RANDOMIZED_SEARCH = 30  # number of RandomizedSearchCV iterations

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"将使用设备: {device}")

# --- 1. WGAN-GP model definitions (Generator, Critic, gradient_penalty) ---
class Generator(nn.Module):
    """Fully connected generator that maps latent noise to one synthetic data row.

    Three hidden layers (128, 256, 512 units) with ReLU activations feed a linear
    output layer of width ``output_dim``; no activation is applied to the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_widths = (128, 256, 512)
        stack = []
        fan_in = input_dim
        for fan_out in hidden_widths:
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
            fan_in = fan_out
        stack.append(nn.Linear(fan_in, output_dim))
        self.model = nn.Sequential(*stack)

    def forward(self, z):
        """Return the generated batch for latent batch ``z``."""
        return self.model(z)

class Critic(nn.Module):
    """Fully connected critic that scores a data row with a single unbounded value.

    Hidden layers of 512, 256 and 128 units use in-place LeakyReLU(0.2); the final
    linear layer outputs one scalar per sample.
    """

    def __init__(self, input_dim):
        super().__init__()
        layer_sizes = [input_dim, 512, 256, 128]
        blocks = []
        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            blocks.extend([nn.Linear(n_in, n_out), nn.LeakyReLU(0.2, inplace=True)])
        blocks.append(nn.Linear(layer_sizes[-1], 1))
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Return the critic score for every row of ``x``."""
        return self.model(x)

def gradient_penalty(critic_model, real_samples, fake_samples, device_in_use):
    """Return the WGAN-GP gradient penalty for one batch.

    Random per-sample interpolations between ``real_samples`` and ``fake_samples``
    are scored by ``critic_model``; the penalty is the batch mean of
    ``(||d critic / d input||_2 - 1) ** 2``.

    Args:
        critic_model: Callable mapping a batch to per-sample scores.
        real_samples: Batch of real data; the first dimension is the batch size.
        fake_samples: Batch of generated data with the same shape as ``real_samples``.
        device_in_use: Device on which the mixing coefficients are created.

    Returns:
        Scalar tensor. It stays differentiable with respect to the critic's
        parameters, but not with respect to the inputs (they are detached).
    """
    # The penalty only regularises the critic, so cut both inputs off from any
    # upstream graph; otherwise the critic's backward pass would also traverse
    # the generator for no benefit.
    real = real_samples.detach()
    fake = fake_samples.detach()
    batch_size_gp = real.size(0)

    # One mixing coefficient per sample, broadcast across all remaining dimensions
    # so that inputs of any rank (not only 2-D tables) are supported.
    alpha_shape = (batch_size_gp,) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device_in_use)
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

    d_interpolates = critic_model(interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,  # the penalty itself must be differentiable
        retain_graph=True,
    )[0]

    # reshape (not view) because the gradient is not guaranteed to be contiguous.
    gradients = gradients.reshape(batch_size_gp, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# --- 2. WGAN-GP training and sample-generation helpers ---
def _wgangp_log_prefix(fold_num_for_logging):
    """Build the log tag: '[WGAN-GP]', '[WGAN-GP Fold k]' or '[WGAN-GP Stage name]'."""
    if isinstance(fold_num_for_logging, int):
        return f"[WGAN-GP Fold {fold_num_for_logging}]"
    if isinstance(fold_num_for_logging, str):  # HPO or final stages
        return f"[WGAN-GP Stage {fold_num_for_logging}]"
    return "[WGAN-GP]"


def _clip_generated_to_original_range(generated_df, reference_df):
    """Clip each generated column in place to the reference range widened by 1%."""
    for col_name in reference_df.columns:
        col_min_original = reference_df[col_name].min()
        col_max_original = reference_df[col_name].max()
        # A constant column has range 0, so the margin vanishes and the bounds collapse.
        margin = 0.01 * (col_max_original - col_min_original)
        clip_min_for_col = col_min_original - margin
        clip_max_for_col = col_max_original + margin

        # These columns must never go negative; extend the tuple if more are needed.
        if col_name in ('CHO', 'PRO'):
            clip_min_for_col = max(0, clip_min_for_col)

        generated_df[col_name] = np.clip(generated_df[col_name], clip_min_for_col, clip_max_for_col)
    return generated_df


def train_and_generate_wgangp(input_original_df,
                              target_col_name,  # not used by the GAN; kept for interface consistency
                              wgan_hyperparams,
                              num_samples_to_generate,
                              current_device,
                              fold_num_for_logging=None,
                              output_augmented_data_path=None):
    """Train a WGAN-GP on ``input_original_df`` and return synthetic rows.

    All columns (features and target alike) are standardised with the statistics of
    the input frame, a fresh Generator/Critic pair is trained, and the generated rows
    are mapped back to the original scale and clipped to the observed column ranges.

    Args:
        input_original_df: Numeric DataFrame to imitate.
        target_col_name: Name of the target column; not read by this function.
        wgan_hyperparams: Dict with keys 'latent_dim', 'lambda_gp', 'n_critic', 'lr',
            'batch_size' and 'epochs'.
        num_samples_to_generate: Number of synthetic rows wanted.
        current_device: Torch device used for training and generation.
        fold_num_for_logging: int (fold number) or str (stage name) shown in log lines.
        output_augmented_data_path: If truthy, the generated rows are also written to
            this Excel path; a failure to save is logged and does not abort.

    Returns:
        DataFrame with the same columns as the input. It is empty when the input has
        no rows, the batch size is 0, or ``num_samples_to_generate`` is not positive.
    """
    log_prefix = _wgangp_log_prefix(fold_num_for_logging)
    print(f"\n{log_prefix} 开始处理，输入数据形状: {input_original_df.shape}")

    all_feature_names = input_original_df.columns.tolist()
    num_rows = len(input_original_df)

    # Guard before any numeric work so an empty frame does not trigger empty-slice warnings.
    current_batch_size = min(wgan_hyperparams['batch_size'], num_rows)
    if current_batch_size == 0:
        print(f"{log_prefix} 错误：数据集为空或过小 ({num_rows} samples), 无法创建DataLoader。")
        return pd.DataFrame(columns=all_feature_names)

    # Nothing to generate: skip the (expensive) training instead of failing later
    # when concatenating an empty list of generated batches.
    if num_samples_to_generate <= 0:
        print(f"{log_prefix} 警告：请求生成的样本数为 {num_samples_to_generate}，跳过WGAN-GP训练。")
        return pd.DataFrame(columns=all_feature_names)

    original_data_values = input_original_df.values.astype(np.float32)

    # (A) Standardise every column with statistics of the current input.
    data_mean = np.mean(original_data_values, axis=0)
    data_std = np.std(original_data_values, axis=0)
    data_std[data_std == 0] = 1  # avoid division by zero for constant columns
    standardized_data = (original_data_values - data_mean) / data_std

    # (B) DataLoader. The last partial batch is dropped only when at least two full
    # batches exist, so the loader always yields at least one batch.
    dataset = TensorDataset(torch.tensor(standardized_data))
    use_drop_last = len(dataset) >= current_batch_size * 2
    dataloader = DataLoader(dataset, batch_size=current_batch_size, shuffle=True, drop_last=use_drop_last)

    # Models and optimisers
    latent_dim = wgan_hyperparams['latent_dim']
    lambda_gp = wgan_hyperparams['lambda_gp']
    n_epochs = wgan_hyperparams['epochs']
    num_features = standardized_data.shape[1]
    generator = Generator(latent_dim, num_features).to(current_device)
    critic = Critic(num_features).to(current_device)
    optimizer_G = optim.Adam(generator.parameters(), lr=wgan_hyperparams['lr'], betas=(0.5, 0.9))
    optimizer_C = optim.Adam(critic.parameters(), lr=wgan_hyperparams['lr'], betas=(0.5, 0.9))

    print(f"{log_prefix} 开始WGAN-GP训练 ({n_epochs} 轮)...")
    # Last observed losses, kept detached so no autograd graph is retained for logging.
    critic_loss_val, generator_loss_val = torch.tensor(0.0), torch.tensor(0.0)
    log_every = max(1, n_epochs // 10)  # report roughly every 10% of the epochs
    for epoch in range(n_epochs):
        for (real_samples_batch,) in dataloader:
            real_samples_batch = real_samples_batch.to(current_device)
            current_real_batch_size = real_samples_batch.size(0)  # actual batch size

            # --- critic updates ---
            for _ in range(wgan_hyperparams['n_critic']):
                optimizer_C.zero_grad()
                z = torch.randn(current_real_batch_size, latent_dim, device=current_device)
                # The generator is frozen during critic updates: build the fake batch
                # without a graph so the backward pass stops at the critic.
                with torch.no_grad():
                    fake_samples_batch = generator(z)
                critic_real = critic(real_samples_batch)
                critic_fake = critic(fake_samples_batch)
                gp = gradient_penalty(critic, real_samples_batch, fake_samples_batch, current_device)
                critic_loss = torch.mean(critic_fake) - torch.mean(critic_real) + lambda_gp * gp
                critic_loss.backward()
                optimizer_C.step()
                critic_loss_val = critic_loss.detach()

            # --- generator update ---
            optimizer_G.zero_grad()
            z = torch.randn(current_real_batch_size, latent_dim, device=current_device)
            generator_loss = -torch.mean(critic(generator(z)))
            generator_loss.backward()
            optimizer_G.step()
            generator_loss_val = generator_loss.detach()

        if (epoch + 1) % log_every == 0:
            print(f"{log_prefix} [Epoch {epoch+1}/{n_epochs}] Critic Loss: {critic_loss_val.item():.4f}, Gen Loss: {generator_loss_val.item():.4f}")

    print(f"{log_prefix} WGAN-GP训练完成。")

    print(f"{log_prefix} 正在生成 {num_samples_to_generate} 个增强样本...")
    generator.eval()
    generated_samples_list = []
    remaining_samples = num_samples_to_generate
    gen_batch_size = wgan_hyperparams['batch_size']  # generate in GAN-sized chunks

    with torch.no_grad():
        while remaining_samples > 0:
            current_gen_size = min(gen_batch_size, remaining_samples)
            z_generate = torch.randn(current_gen_size, latent_dim, device=current_device)
            generated_samples_list.append(generator(z_generate).cpu().numpy())
            remaining_samples -= current_gen_size

    generated_standardized_data_np = np.concatenate(generated_samples_list, axis=0)

    # Undo the standardisation.
    generated_data_original_scale_np = generated_standardized_data_np * data_std + data_mean
    generated_data_df = pd.DataFrame(generated_data_original_scale_np, columns=all_feature_names)

    # Post-processing: keep every column (CHO included) inside the observed range.
    _clip_generated_to_original_range(generated_data_df, input_original_df)
    print(f"{log_prefix} 后处理完成。")

    if output_augmented_data_path:
        # Saving is best-effort: a failure is reported but the generated data is still returned.
        try:
            output_dir = os.path.dirname(output_augmented_data_path)
            if output_dir:  # a bare filename has no directory to create
                os.makedirs(output_dir, exist_ok=True)
            generated_data_df.to_excel(output_augmented_data_path, index=False)
            print(f"{log_prefix} 增强数据已保存到: {output_augmented_data_path}")
        except Exception as e:
            print(f"{log_prefix} 保存增强数据时发生错误: {e}")

    return generated_data_df

# --- 3. Main pipeline ---
# Load the development set and the held-out final test set.
try:
    development_df_original = pd.read_excel(DEV_SET_FILE)
    final_test_df_original = pd.read_excel(TEST_SET_FILE)
except FileNotFoundError as e:
    print(f"错误: 开发集或测试集文件未找到。请检查路径: {e}")
    # exit() is the interactive helper injected by `site` and reports success (status 0);
    # raise SystemExit with a non-zero status so the failure is visible to callers.
    raise SystemExit(1)
print(f"开发集形状: {development_df_original.shape}, 最终测试集形状: {final_test_df_original.shape}")

# Split both frames into features and target.
X_dev_original = development_df_original.drop(columns=[TARGET_COLUMN])
y_dev_original = development_df_original[TARGET_COLUMN]
X_final_test = final_test_df_original.drop(columns=[TARGET_COLUMN])
y_final_test = final_test_df_original[TARGET_COLUMN]

# --- Step 3, Part 1: choose the XGBoost hyperparameters ---
print("\n--- 步骤三：Part 1 - XGBoost 超参数调优 ---")
print("为超参数调优生成开发集的增强版本...")
num_augmented_samples_for_hpo = len(development_df_original) * 1
current_wgan_hpo_params = DEFAULT_WGAN_PARAMS.copy()
current_wgan_hpo_params['epochs'] = DEFAULT_WGAN_PARAMS['epochs_for_final']
augmented_dev_for_hpo_output_path = os.path.join(AUGMENTED_DATA_OUTPUT_FOLDER, "augmented_dev_for_hpo.xlsx")

augmented_dev_for_hpo_df = train_and_generate_wgangp(
    input_original_df=development_df_original.copy(),
    target_col_name=TARGET_COLUMN,
    wgan_hyperparams=current_wgan_hpo_params,
    num_samples_to_generate=num_augmented_samples_for_hpo,
    current_device=device,
    fold_num_for_logging="HPO_Dev_Set",
    output_augmented_data_path=augmented_dev_for_hpo_output_path
)
if augmented_dev_for_hpo_df.empty:
    print("错误：为HPO生成的增强数据为空，无法继续进行超参数调优。请检查WGAN-GP流程。")
    # Abort with a non-zero status; exit() is interactive-only and would report success.
    raise SystemExit(1)

# Stack the original and synthetic rows into one search data set.
# NOTE(review): the synthetic rows are generated from the whole development set, so the
# validation folds used inside RandomizedSearchCV contain samples derived from their
# training folds; the search score below is therefore likely optimistic.
X_augmented_dev_for_hpo = augmented_dev_for_hpo_df.drop(columns=[TARGET_COLUMN])
y_augmented_dev_for_hpo = augmented_dev_for_hpo_df[TARGET_COLUMN]
X_combined_dev_for_hpo = pd.concat([X_dev_original, X_augmented_dev_for_hpo], ignore_index=True)
y_combined_dev_for_hpo = pd.concat([y_dev_original, y_augmented_dev_for_hpo], ignore_index=True)
print(f"用于HPO的总数据形状: {X_combined_dev_for_hpo.shape}")

xgb_regressor_for_hpo = xgb.XGBRegressor(objective='reg:squarederror', random_state=RANDOM_STATE, tree_method='gpu_hist' if device.type == 'cuda' else 'hist')
random_search_hpo = RandomizedSearchCV(
    estimator=xgb_regressor_for_hpo, param_distributions=XGB_PARAM_GRID,
    n_iter=N_ITER_RANDOMIZED_SEARCH, cv=N_SPLITS_KFOLD,
    scoring='neg_mean_absolute_error', verbose=1, random_state=RANDOM_STATE, n_jobs=-1
)
print("开始XGBoost超参数搜索 (RandomizedSearchCV)...")
random_search_hpo.fit(X_combined_dev_for_hpo, y_combined_dev_for_hpo)
best_overall_xgboost_params = random_search_hpo.best_params_
print(f"找到的最佳XGBoost超参数: {best_overall_xgboost_params}")
print(f"最佳HPO MAE (负值): {random_search_hpo.best_score_}")

# --- Step 3, Part 2: K-fold cross-validation with per-fold WGAN-GP augmentation ---
print(f"\n--- 步骤三：Part 2 - 在开发集上进行 {N_SPLITS_KFOLD}-折交叉验证 (动态WGAN-GP增强) ---")
kf = KFold(n_splits=N_SPLITS_KFOLD, shuffle=True, random_state=RANDOM_STATE)

kfold_cv_val_metrics_list = []
kfold_cv_train_metrics_list = []
cv_train_mae_curves = []
cv_val_mae_curves = []

current_wgan_cv_params = DEFAULT_WGAN_PARAMS.copy()
current_wgan_cv_params['epochs'] = DEFAULT_WGAN_PARAMS['epochs_for_cv']
augmentation_factor_cv = 1

for fold_idx, (train_indices, val_indices) in enumerate(kf.split(development_df_original)):
    print(f"\n--- K-Fold: 第 {fold_idx + 1}/{N_SPLITS_KFOLD} 折 ---")
    cv_train_original_fold_df = development_df_original.iloc[train_indices]
    cv_val_original_fold_df = development_df_original.iloc[val_indices]

    X_cv_val_fold = cv_val_original_fold_df.drop(columns=[TARGET_COLUMN])
    y_cv_val_fold = cv_val_original_fold_df[TARGET_COLUMN]

    print(f"当前CV训练集（原始）形状: {cv_train_original_fold_df.shape}")
    num_augmented_samples_cv_fold = len(cv_train_original_fold_df) * augmentation_factor_cv
    # Cap the GAN batch size at half the fold's training rows (never below 1).
    dynamic_wgan_cv_params = current_wgan_cv_params.copy()
    dynamic_wgan_cv_params['batch_size'] = min(
        current_wgan_cv_params['batch_size'], max(1, len(cv_train_original_fold_df) // 2)
    )

    # The GAN only sees this fold's training rows, so the validation rows stay unseen.
    cv_augmented_fold_df = train_and_generate_wgangp(
        input_original_df=cv_train_original_fold_df.copy(),
        target_col_name=TARGET_COLUMN,
        wgan_hyperparams=dynamic_wgan_cv_params,
        num_samples_to_generate=num_augmented_samples_cv_fold,
        current_device=device,
        fold_num_for_logging=(fold_idx + 1),
        output_augmented_data_path=None  # per-fold data is not saved
    )
    if cv_augmented_fold_df.empty:
        print(f"警告：Fold {fold_idx + 1} 的增强数据为空，跳过此折。")
        cv_train_mae_curves.append([])  # placeholders keep list positions aligned with fold numbers
        cv_val_mae_curves.append([])
        continue

    X_cv_train_original_fold = cv_train_original_fold_df.drop(columns=[TARGET_COLUMN])
    y_cv_train_original_fold = cv_train_original_fold_df[TARGET_COLUMN]
    X_cv_augmented_fold = cv_augmented_fold_df.drop(columns=[TARGET_COLUMN])
    y_cv_augmented_fold = cv_augmented_fold_df[TARGET_COLUMN]

    X_cv_train_combined_fold = pd.concat([X_cv_train_original_fold, X_cv_augmented_fold], ignore_index=True)
    y_cv_train_combined_fold = pd.concat([y_cv_train_original_fold, y_cv_augmented_fold], ignore_index=True)
    print(f"当前CV训练集（原始+增强后）形状: {X_cv_train_combined_fold.shape}")

    # eval_metric / early_stopping_rounds are estimator parameters: passing them to
    # fit() is deprecated and rejected by recent XGBoost releases.
    model_fold = xgb.XGBRegressor(
        **best_overall_xgboost_params, objective='reg:squarederror',
        random_state=RANDOM_STATE, tree_method='gpu_hist' if device.type == 'cuda' else 'hist',
        eval_metric='mae', early_stopping_rounds=10
    )

    # Early stopping monitors the last entry of eval_set, i.e. the validation fold.
    eval_set_fold = [(X_cv_train_combined_fold, y_cv_train_combined_fold), (X_cv_val_fold, y_cv_val_fold)]

    print(f"Fold {fold_idx + 1}: 开始训练XGBoost模型...")
    model_fold.fit(X_cv_train_combined_fold, y_cv_train_combined_fold,
                   eval_set=eval_set_fold, verbose=False)

    # 'validation_0' is the combined training set, 'validation_1' the validation fold.
    fold_eval_results = model_fold.evals_result()
    cv_train_mae_curves.append(fold_eval_results['validation_0']['mae'])
    cv_val_mae_curves.append(fold_eval_results['validation_1']['mae'])

    y_pred_val_fold = model_fold.predict(X_cv_val_fold)
    mae_val = mean_absolute_error(y_cv_val_fold, y_pred_val_fold)
    rmse_val = np.sqrt(mean_squared_error(y_cv_val_fold, y_pred_val_fold))
    r2_val = r2_score(y_cv_val_fold, y_pred_val_fold)
    kfold_cv_val_metrics_list.append({'fold': fold_idx + 1, 'MAE': mae_val, 'RMSE': rmse_val, 'R2': r2_val})

    y_pred_train_fold = model_fold.predict(X_cv_train_combined_fold)
    mae_train = mean_absolute_error(y_cv_train_combined_fold, y_pred_train_fold)
    rmse_train = np.sqrt(mean_squared_error(y_cv_train_combined_fold, y_pred_train_fold))
    r2_train = r2_score(y_cv_train_combined_fold, y_pred_train_fold)
    kfold_cv_train_metrics_list.append({'fold': fold_idx + 1, 'MAE': mae_train, 'RMSE': rmse_train, 'R2': r2_train})

    print(f"Fold {fold_idx + 1} - CV Train MAE: {mae_train:.4f}, R2: {r2_train:.4f} | CV Val MAE: {mae_val:.4f}, R2: {r2_val:.4f}")

# Average the per-fold validation metrics.
avg_kfold_cv_val_metrics_df = pd.DataFrame(kfold_cv_val_metrics_list)
print("\n--- K-折交叉验证平均CV验证性能 (开发集, WGAN-GP动态增强) ---")
if avg_kfold_cv_val_metrics_df.empty:
    print("K-Fold CV验证结果为空，无法计算平均值。")
    avg_mae_cv_val, avg_r2_cv_val = np.nan, np.nan
else:
    avg_mae_cv_val, avg_rmse_cv_val, avg_r2_cv_val = (
        avg_kfold_cv_val_metrics_df[metric].mean() for metric in ('MAE', 'RMSE', 'R2')
    )
    print(f"平均 CV 验证集 MAE: {avg_mae_cv_val:.4f}")
    print(f"平均 CV 验证集 RMSE: {avg_rmse_cv_val:.4f}")
    print(f"平均 CV 验证集 R2: {avg_r2_cv_val:.4f}")

# Average the per-fold training metrics.
avg_kfold_cv_train_metrics_df = pd.DataFrame(kfold_cv_train_metrics_list)
print("\n--- K-折交叉验证平均CV训练性能 (开发集, WGAN-GP动态增强) ---")
if avg_kfold_cv_train_metrics_df.empty:
    print("K-Fold CV训练结果为空，无法计算平均值。")
    avg_mae_cv_train, avg_r2_cv_train = np.nan, np.nan
else:
    avg_mae_cv_train, avg_rmse_cv_train, avg_r2_cv_train = (
        avg_kfold_cv_train_metrics_df[metric].mean() for metric in ('MAE', 'RMSE', 'R2')
    )
    print(f"平均 CV 训练集 MAE: {avg_mae_cv_train:.4f}")
    print(f"平均 CV 训练集 R2: {avg_r2_cv_train:.4f}")

if not (np.isnan(avg_mae_cv_val) or np.isnan(avg_mae_cv_train)):
    print("\n--- 步骤三 Part 2 结束：生成K-Fold性能图表 ---")

    # Figure 1: grouped bars comparing average train vs. validation MAE and R2.
    metrics_plot_names_en = ['MAE', 'R2 Score']
    values_cv_train_plot = [avg_mae_cv_train, avg_r2_cv_train]
    values_cv_val_plot = [avg_mae_cv_val, avg_r2_cv_val]
    x_axis_plot = np.arange(len(metrics_plot_names_en))
    plt.figure(figsize=(10, 6))
    plt.bar(x_axis_plot - 0.2, values_cv_train_plot, width=0.4, label='CV Train Avg.', align='center')
    plt.bar(x_axis_plot + 0.2, values_cv_val_plot, width=0.4, label='CV Validation Avg.', align='center')
    plt.xticks(x_axis_plot, metrics_plot_names_en)
    plt.ylabel('Score')
    plt.title('Average K-Fold CV Train vs. CV Validation Metrics (WGAN-GP Augmented)')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)
    plot_filename_metrics = os.path.join(OUTPUT_PLOT_PATH, "kfold_avg_eval_metrics_wgangp.png")
    try:
        plt.savefig(plot_filename_metrics, dpi=300, bbox_inches='tight')
        print(f"图表1已保存到: {plot_filename_metrics}")
    except Exception as e:
        print(f"保存图表1时发生错误: {e}")
    plt.show()

    # Figure 2: per-fold MAE learning curves; needs at least one non-empty curve on each side.
    have_curve_data = any(cv_train_mae_curves) and any(cv_val_mae_curves)
    if not have_curve_data:
        print("没有收集到足够的MAE曲线数据用于绘制图表2。")
    else:
        plt.figure(figsize=(14, 7))
        for fold_no, (train_curve, val_curve) in enumerate(zip(cv_train_mae_curves, cv_val_mae_curves), start=1):
            if not (train_curve and val_curve):
                continue  # skipped fold: placeholder lists are empty
            plt.plot(train_curve, label=f'CV Train Fold {fold_no}', linestyle='-')
            plt.plot(val_curve, label=f'CV Validation Fold {fold_no}', linestyle='--')
        plt.xlabel('Boosting Round')
        plt.ylabel('Mean Absolute Error (MAE)')
        plt.title('Training and CV Validation MAE per Fold (WGAN-GP Augmented)')
        if len(cv_train_mae_curves) > 5:
            print("图例条目过多（超过5折），未在图上显示以保持清晰。")
        else:
            plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)
        plot_filename_loss = os.path.join(OUTPUT_PLOT_PATH, "kfold_mae_loss_per_fold_wgangp.png")
        try:
            plt.savefig(plot_filename_loss, dpi=300, bbox_inches='tight')
            print(f"图表2已保存到: {plot_filename_loss}")
        except Exception as e:
            print(f"保存图表2时发生错误: {e}")
        plt.show()
else:
    print("由于K-Fold平均结果为空或NaN，跳过图表绘制。")

# --- Step 4: train the final model ---
print("\n--- 步骤四：训练最终模型 ---")
print("为最终模型生成开发集的完整增强版本...")
final_wgan_params = DEFAULT_WGAN_PARAMS.copy()
final_wgan_params['epochs'] = DEFAULT_WGAN_PARAMS['epochs_for_final']
num_augmented_samples_final = len(development_df_original) * 2
final_augmented_dev_output_path = os.path.join(AUGMENTED_DATA_OUTPUT_FOLDER, "augmented_dev_for_final_model.xlsx")

final_augmented_dev_df = train_and_generate_wgangp(
    input_original_df=development_df_original.copy(),
    target_col_name=TARGET_COLUMN,
    wgan_hyperparams=final_wgan_params,
    num_samples_to_generate=num_augmented_samples_final,
    current_device=device,
    fold_num_for_logging="Final_Dev_Set_Augmentation",
    output_augmented_data_path=final_augmented_dev_output_path
)
if final_augmented_dev_df.empty:
    print("错误：为最终模型生成的增强数据为空，无法继续。请检查WGAN-GP流程。")
    # Abort with a non-zero status; exit() is interactive-only and would report success.
    raise SystemExit(1)

X_final_augmented_dev = final_augmented_dev_df.drop(columns=[TARGET_COLUMN])
y_final_augmented_dev = final_augmented_dev_df[TARGET_COLUMN]
X_train_final_model = pd.concat([X_dev_original, X_final_augmented_dev], ignore_index=True)
y_train_final_model = pd.concat([y_dev_original, y_final_augmented_dev], ignore_index=True)
print(f"用于训练最终模型的总数据形状: {X_train_final_model.shape}")

final_model_params = {
    **best_overall_xgboost_params,
    'objective': 'reg:squarederror',
    'random_state': RANDOM_STATE,
    'tree_method': 'gpu_hist' if device.type == 'cuda' else 'hist',
}
print("开始训练最终模型...")
# To allow early stopping on the final model, hold out a small temporary validation
# split from the final training data. The final test set is not touched.
if len(X_train_final_model) > 10:  # enough rows to split
    X_fm_train, X_fm_val, y_fm_train, y_fm_val = train_test_split(
        X_train_final_model, y_train_final_model, test_size=0.1, random_state=RANDOM_STATE
    )
    eval_set_final = [(X_fm_train, y_fm_train), (X_fm_val, y_fm_val)]
    # eval_metric / early_stopping_rounds are estimator parameters: passing them to
    # fit() is deprecated and rejected by recent XGBoost releases.
    final_model = xgb.XGBRegressor(**final_model_params, eval_metric='mae', early_stopping_rounds=10)
    final_model.fit(X_fm_train, y_fm_train, eval_set=eval_set_final, verbose=False)
    # Optional follow-up: re-create the model with n_estimators set to
    # final_model.best_iteration and refit it on all of X_train_final_model.
else:  # too few rows: no validation split, hence no early stopping
    final_model = xgb.XGBRegressor(**final_model_params)
    final_model.fit(X_train_final_model, y_train_final_model)


print("最终模型训练完成。")
final_model_path = os.path.join(MODEL_OUTPUT_PATH, "final_xgboost_wgangp_model.joblib")
joblib.dump(final_model, final_model_path)
print(f"最终模型已保存到: {final_model_path}")

print("\n--- 步骤四结束：生成最终模型特征重要性图 ---")
feature_importances = final_model.feature_importances_
# Feature names come from the original development set; the augmented data shares
# the same columns in the same order.
importance_df = (
    pd.DataFrame({'Feature': X_dev_original.columns, 'Importance': feature_importances})
    .sort_values(by='Importance', ascending=False)
)

# Show at most 20 features and scale the figure height with the number of bars.
num_features_to_plot = min(20, len(X_dev_original.columns))
plot_height = max(6, num_features_to_plot * 0.4)
top_importances = importance_df.head(num_features_to_plot)

plt.figure(figsize=(10, plot_height))
plt.barh(top_importances['Feature'], top_importances['Importance'])
plt.xlabel('Feature Importance')
plt.ylabel('Feature')
plt.title(f'Top {num_features_to_plot} Feature Importances (Final Model with WGAN-GP Augmentation)')
plt.gca().invert_yaxis()  # most important feature on top
plt.tight_layout()
plt.grid(True, axis='x', linestyle='--', alpha=0.7)

plot_filename_importance = os.path.join(OUTPUT_PLOT_PATH, "final_model_feature_importances_wgangp.png")
try:
    plt.savefig(plot_filename_importance, dpi=300, bbox_inches='tight')
    print(f"最终模型特征重要性图已保存到: {plot_filename_importance}")
except Exception as e:
    print(f"保存最终模型特征重要性图时发生错误: {e}")
plt.show()

# --- Step 5: evaluation on the held-out final test set ---
print("\n--- 步骤五：在最终测试集上进行无偏评估 ---")
# final_model = joblib.load(final_model_path)  # uncomment when running the steps separately
y_pred_final_test = final_model.predict(X_final_test)
mae_final = mean_absolute_error(y_final_test, y_pred_final_test)
rmse_final = np.sqrt(mean_squared_error(y_final_test, y_pred_final_test))
r2_final = r2_score(y_final_test, y_pred_final_test)
print("--- 最终模型在最终测试集上的性能 ---")
print(f"MAE: {mae_final:.4f}")
print(f"RMSE: {rmse_final:.4f}")
print(f"R2 Score: {r2_final:.4f}")

print("\n--- 步骤五结束：生成最终测试集真实值 vs 预测值图 ---")
# The y = x reference line spans the joint range of actual and predicted values.
min_val = min(y_final_test.min(), y_pred_final_test.min())
max_val = max(y_final_test.max(), y_pred_final_test.max())
diagonal_ends = [min_val, max_val]

plt.figure(figsize=(8, 8))
plt.scatter(y_final_test, y_pred_final_test, alpha=0.7, edgecolors='w', linewidth=0.5)
plt.plot(diagonal_ends, diagonal_ends, 'k--', lw=2)
plt.xlabel('Actual Rowing Distance')
plt.ylabel('Predicted Rowing Distance')
plt.title('Final Model: Actual vs. Predicted Rowing Distance (Test Set)')
plt.grid(True, linestyle='--', alpha=0.7)

plot_filename_actual_vs_pred = os.path.join(OUTPUT_PLOT_PATH, "final_model_actual_vs_predicted_wgangp.png")
try:
    plt.savefig(plot_filename_actual_vs_pred, dpi=300, bbox_inches='tight')
    print(f"最终测试集真实值 vs 预测值图已保存到: {plot_filename_actual_vs_pred}")
except Exception as e:
    print(f"保存最终测试集真实值 vs 预测值图时发生错误: {e}")
plt.show()

print("\n--- 整体流程执行完毕 ---")