# In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from sklearn.model_selection import KFold, train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import matplotlib.pyplot as plt
import os
import random
import time
import joblib # 虽然BPNN不用joblib存，但保留以防万一有其他用途或习惯


from pathlib import Path 
# Project layout is resolved from the current working directory: the project
# root is the parent of the cwd, and data/ and output/ live under that root.
CURRENT_DIR = Path.cwd()
PROJECT_ROOT = CURRENT_DIR.parent
DATA_DIR = PROJECT_ROOT / "data"
OUTPUT_DIR = PROJECT_ROOT / "output"



# --- 0. Global configuration ---
# Data paths
BASE_DATA_PATH = DATA_DIR
DEV_SET_FILE = DATA_DIR / "development_set_selected_features.xlsx"
TEST_SET_FILE = DATA_DIR / "final_test_set_selected_features.xlsx"
AUGMENTED_DATA_OUTPUT_FOLDER =  DATA_DIR / "augmented_outputs_bpnn" # Output folder for the BPNN augmented data
MODEL_OUTPUT_PATH =  DATA_DIR / "trained_models_bpnn"      # Folder where BPNN models are saved
OUTPUT_PLOT_PATH = OUTPUT_DIR # Folder where plots are exported

# Create every output folder up front (no error if it already exists).
os.makedirs(AUGMENTED_DATA_OUTPUT_FOLDER, exist_ok=True)
os.makedirs(MODEL_OUTPUT_PATH, exist_ok=True)
os.makedirs(OUTPUT_PLOT_PATH, exist_ok=True)

TARGET_COLUMN = 'Rowing distance'
RANDOM_STATE = 42
N_SPLITS_KFOLD = 5 # Number of folds for K-Fold

# WGAN-GP preset parameters (per the original author: same as in the XGBoost script)
# CHO_LEVELS = np.array([0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2])
DEFAULT_WGAN_PARAMS = {
    'latent_dim': 100,
    'lambda_gp': 10,
    'n_critic': 5,
    'lr': 0.00005, # WGAN-GP learning rate
    'batch_size': 32, # WGAN-GP batch size
    'epochs_for_cv': 500, # WGAN-GP epochs for the per-fold dynamic augmentation inside K-Fold
    'epochs_for_final_hpo': 2000 # WGAN-GP epochs when augmenting the dev set for HPO / the final model (was epochs_for_final: 3000; adjust as needed)
}

# BPNN hyperparameter search space (author's note: chosen with regularization and moderate simplification in mind)
BPNN_PARAM_GRID = {
    'learning_rate': [0.0005, 0.001, 0.005], # BPNN learning rate
    'hidden_dims': [
        [64],
        [32, 16],
        [64, 32],
        [128, 64] # Slightly more complex option, since augmented data is available
    ],
    'batch_size': [16, 32, 64], # BPNN batch size
    'dropout_rate': [0.2, 0.3, 0.4, 0.5],
    'weight_decay': [1e-5, 1e-4, 5e-4, 1e-3]
}
N_ITER_BPNN_HPO = 25 # Number of BPNN HPO iterations (the XGB script used 30)

# BPNN training epochs and early-stopping patience
EPOCHS_BPNN_HPO = 75       # Epochs per configuration during BPNN HPO (previously 50)
PATIENCE_BPNN_HPO = 10      # Early-stopping patience during BPNN HPO (previously 5)
EPOCHS_BPNN_CV_FINAL = 250  # Epochs for the CV folds and the final model (previously 200)
PATIENCE_BPNN_CV_FINAL = 20 # Early-stopping patience for the CV folds and the final model (previously 15)


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"将使用设备: {device}")

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (plus all CUDA devices when CUDA is available), and puts cuDNN into
    deterministic mode with its benchmark autotuner switched off.

    Args:
        seed: Integer seed shared by all generators.
    """
    # cuDNN flags: request deterministic kernels, disable autotuning.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
# Seed all RNGs once, before any model or data split is created.
set_seed(RANDOM_STATE)

# --- 1. WGAN-GP 模型定义 (与XGBoost脚本一致) ---
class Generator(nn.Module):
    """WGAN-GP generator: an MLP that maps latent noise to one data row.

    Layer widths are ``input_dim -> 128 -> 256 -> 512 -> output_dim``; every
    layer except the last is followed by a ReLU, the last one is linear.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        widths = (input_dim, 128, 256, 512)
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.append(nn.Linear(widths[-1], output_dim))
        # The attribute name and the layer order define the state_dict keys.
        self.model = nn.Sequential(*stack)

    def forward(self, z):
        """Return the generated batch for the latent batch ``z``."""
        return self.model(z)

class Critic(nn.Module):
    """WGAN-GP critic: an MLP that scores a data row with one unbounded scalar.

    Layer widths are ``input_dim -> 512 -> 256 -> 128 -> 1``; every layer
    except the last is followed by an in-place LeakyReLU with slope 0.2.
    """

    def __init__(self, input_dim):
        super().__init__()
        widths = (input_dim, 512, 256, 128)
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        stack.append(nn.Linear(widths[-1], 1))
        # The attribute name and the layer order define the state_dict keys.
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return the critic score for every row of ``x``."""
        return self.model(x)

def gradient_penalty(critic_model, real_samples, fake_samples, device_in_use):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolations.

    For every row a mixing coefficient ``alpha ~ U[0, 1)`` is drawn, the
    critic is evaluated at ``alpha * real + (1 - alpha) * fake`` and the
    penalty is the batch mean of ``(||d critic / d x||_2 - 1) ** 2``.

    Both inputs are detached first: the penalty only has to be differentiable
    with respect to the critic's parameters, so no gradient is propagated
    back into whatever produced ``fake_samples`` (e.g. the generator).

    Args:
        critic_model: Callable mapping a batch to one score per row.
        real_samples: Batch of real rows, first dimension is the batch.
        fake_samples: Batch of generated rows with the same shape.
        device_in_use: Device on which the mixing coefficients are created.

    Returns:
        Scalar tensor holding the penalty; it keeps its graph
        (``create_graph=True``) so it can be part of the critic loss.
    """
    real_detached = real_samples.detach()
    fake_detached = fake_samples.detach()
    n_rows = real_detached.size(0)

    # One coefficient per row, broadcast over all remaining dimensions.
    mix_shape = (n_rows,) + (1,) * (real_detached.dim() - 1)
    alpha = torch.rand(mix_shape, device=device_in_use)

    interpolates = (alpha * real_detached + (1 - alpha) * fake_detached).requires_grad_(True)
    d_interpolates = critic_model(interpolates)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,  # the penalty itself must be differentiable
    )[0]
    # reshape (not view) so non-contiguous gradients are handled as well.
    gradients = gradients.reshape(n_rows, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# --- 2. WGAN-GP 训练与生成辅助函数 (与XGBoost脚本基本一致) ---
def train_and_generate_wgangp(input_original_df,
                              target_col_name,
                              wgan_hyperparams,
                              num_samples_to_generate,
                              current_device,
                              fold_num_for_logging=None,
                              output_augmented_data_path=None):
    """Train a WGAN-GP on a tabular DataFrame and sample synthetic rows from it.

    All columns of ``input_original_df`` (features and target alike) are
    z-score standardized and modelled jointly. Generated rows are mapped back
    to the original scale and clipped per column to the observed range
    widened by 1% of that range on each side; ``CHO`` and ``PRO`` are
    additionally kept non-negative.

    Args:
        input_original_df: Numeric DataFrame to learn from; its values are
            cast to float32.
        target_col_name: Not used by this function; kept so existing callers
            keep working.
        wgan_hyperparams: Dict providing ``latent_dim``, ``lambda_gp``,
            ``n_critic``, ``lr``, ``batch_size`` and ``epochs``.
        num_samples_to_generate: Number of synthetic rows to draw.
        current_device: Torch device used for training and sampling.
        fold_num_for_logging: ``int`` (logged as "Fold") or ``str`` (logged
            as "Stage"); only affects the log prefix.
        output_augmented_data_path: Optional Excel path. When given, the
            generated rows are written there; a failure is logged, not raised.

    Returns:
        DataFrame with the input's columns and ``num_samples_to_generate``
        rows, or an empty DataFrame with those columns when the input has no
        rows, the configured batch size is 0, or no samples were requested.
    """
    log_prefix = "[WGAN-GP"
    if isinstance(fold_num_for_logging, int):
        log_prefix += f" Fold {fold_num_for_logging}"
    elif isinstance(fold_num_for_logging, str):
        log_prefix += f" Stage {fold_num_for_logging}"
    log_prefix += "]"
    print(f"\n{log_prefix} 开始WGAN-GP处理，输入数据形状: {input_original_df.shape}")

    all_feature_names = input_original_df.columns.tolist()
    n_rows = len(input_original_df)

    # Bail out before any work when there is nothing to train on ...
    current_batch_size = min(wgan_hyperparams['batch_size'], n_rows)
    if current_batch_size == 0:
        print(f"{log_prefix} 错误：WGAN-GP数据集为空或过小 ({n_rows} samples), 无法创建DataLoader。")
        return pd.DataFrame(columns=all_feature_names)
    # ... or nothing to generate (concatenating zero batches would raise).
    if num_samples_to_generate <= 0:
        print(f"{log_prefix} 错误：请求生成的样本数为 {num_samples_to_generate}，无法生成增强数据。")
        return pd.DataFrame(columns=all_feature_names)

    # Z-score standardization; constant columns get std 1 to avoid dividing by 0.
    original_data_values = input_original_df.values.astype(np.float32)
    data_mean = np.mean(original_data_values, axis=0)
    data_std = np.std(original_data_values, axis=0)
    data_std[data_std == 0] = 1
    standardized_data = (original_data_values - data_mean) / data_std

    dataset = TensorDataset(torch.tensor(standardized_data, dtype=torch.float32))
    # Drop the ragged last batch only when at least two full batches exist,
    # so the loader always yields at least one batch.
    use_drop_last = len(dataset) >= current_batch_size * 2
    dataloader = DataLoader(dataset, batch_size=current_batch_size, shuffle=True, drop_last=use_drop_last)

    latent_dim = wgan_hyperparams['latent_dim']
    n_epochs = wgan_hyperparams['epochs']
    n_critic = wgan_hyperparams['n_critic']
    lambda_gp = wgan_hyperparams['lambda_gp']

    num_features = standardized_data.shape[1]
    generator = Generator(latent_dim, num_features).to(current_device)
    critic = Critic(num_features).to(current_device)
    optimizer_G = optim.Adam(generator.parameters(), lr=wgan_hyperparams['lr'], betas=(0.5, 0.9))
    optimizer_C = optim.Adam(critic.parameters(), lr=wgan_hyperparams['lr'], betas=(0.5, 0.9))

    print(f"{log_prefix} 开始WGAN-GP训练 ({n_epochs} 轮)...")
    log_every = max(1, n_epochs // 10)
    for epoch in range(n_epochs):
        for (real_samples_batch,) in dataloader:
            real_samples_batch = real_samples_batch.to(current_device)
            current_real_batch_size = real_samples_batch.size(0)

            # --- critic: n_critic updates per generator update ---
            for _ in range(n_critic):
                optimizer_C.zero_grad()
                z = torch.randn(current_real_batch_size, latent_dim, device=current_device)
                # The generator is fixed during the critic step, so no graph
                # through it is needed.
                with torch.no_grad():
                    fake_samples_batch = generator(z)
                critic_real = critic(real_samples_batch)
                critic_fake = critic(fake_samples_batch)
                gp = gradient_penalty(critic, real_samples_batch, fake_samples_batch, current_device)
                critic_loss = torch.mean(critic_fake) - torch.mean(critic_real) + lambda_gp * gp
                critic_loss.backward()
                optimizer_C.step()

            # --- generator: maximize the critic score of generated rows ---
            optimizer_G.zero_grad()
            z = torch.randn(current_real_batch_size, latent_dim, device=current_device)
            generator_loss = -torch.mean(critic(generator(z)))
            generator_loss.backward()
            optimizer_G.step()

        # Log the losses of the last batch roughly ten times per run.
        if (epoch + 1) % log_every == 0:
            print(f"{log_prefix} [Epoch {epoch+1}/{n_epochs}] WGAN Critic Loss: {critic_loss.item():.4f}, Gen Loss: {generator_loss.item():.4f}")
    print(f"{log_prefix} WGAN-GP训练完成。")

    print(f"{log_prefix} 正在生成 {num_samples_to_generate} 个增强样本...")
    generator.eval()
    generated_samples_list = []
    remaining_samples = num_samples_to_generate
    gen_batch_size = wgan_hyperparams['batch_size']
    with torch.no_grad():
        while remaining_samples > 0:
            current_gen_size = min(gen_batch_size, remaining_samples)
            z_generate = torch.randn(current_gen_size, latent_dim, device=current_device)
            generated_samples_list.append(generator(z_generate).cpu().numpy())
            remaining_samples -= current_gen_size
    generated_standardized_data_np = np.concatenate(generated_samples_list, axis=0)

    # Undo the standardization.
    generated_data_original_scale_np = generated_standardized_data_np * data_std + data_mean
    generated_data_df = pd.DataFrame(generated_data_original_scale_np, columns=all_feature_names)

    # Clip every column to the observed range widened by 1% on each side.
    non_negative_columns = ('CHO', 'PRO')  # extend if more columns must stay >= 0
    for col_name in all_feature_names:
        original_col_values = input_original_df[col_name]
        col_min_original = original_col_values.min()
        col_max_original = original_col_values.max()
        margin = 0.01 * (col_max_original - col_min_original)  # 0 for constant columns

        clip_min_for_col = col_min_original - margin
        clip_max_for_col = col_max_original + margin
        if col_name in non_negative_columns:
            clip_min_for_col = max(0, clip_min_for_col)

        generated_data_df[col_name] = np.clip(generated_data_df[col_name], clip_min_for_col, clip_max_for_col)
    print(f"{log_prefix} 后处理完成。")

    if output_augmented_data_path:
        # Saving is best-effort: a failure is reported but the data is still returned.
        try:
            output_dir = os.path.dirname(output_augmented_data_path)
            if output_dir:  # a bare filename has no directory to create
                os.makedirs(output_dir, exist_ok=True)
            generated_data_df.to_excel(output_augmented_data_path, index=False)
            print(f"{log_prefix} 增强数据已保存到: {output_augmented_data_path}")
        except Exception as e:
            print(f"{log_prefix} 保存增强数据时发生错误: {e}")
    return generated_data_df

# --- 3. BPNN 模型定义 (与之前的BPNN脚本一致) ---
class BPNN(nn.Module):
    """Fully connected feed-forward network (back-propagation neural network).

    Each entry of ``hidden_dims`` adds a ``Linear -> ReLU -> Dropout`` group;
    a final ``Linear`` layer maps to ``output_dim`` without activation.
    """

    def __init__(self, input_dim, hidden_dims, output_dim=1, dropout_rate=0.2):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        modules = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            modules.extend((nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout_rate)))
        modules.append(nn.Linear(widths[-1], output_dim))
        # The attribute name and the layer order define the state_dict keys.
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Return the network output for the batch ``x``."""
        return self.network(x)

# --- 4. BPNN 训练和评估函数 (与之前的BPNN脚本一致，但y_scaler是必须的) ---
def _predict_unscaled(model, loader, y_scaler):
    """Evaluate ``model`` on ``loader``; return ``(targets, predictions)`` in original units."""
    model.eval()
    pred_chunks, target_chunks = [], []
    with torch.no_grad():
        for batch_X, batch_y_scaled in loader:
            pred_chunks.append(model(batch_X).cpu().numpy())
            target_chunks.append(batch_y_scaled.cpu().numpy())
    predictions = y_scaler.inverse_transform(np.concatenate(pred_chunks, axis=0)).flatten()
    targets = y_scaler.inverse_transform(np.concatenate(target_chunks, axis=0)).flatten()
    return targets, predictions


def train_evaluate_bpnn_fold(model_config, X_train_fold_scaled, y_train_fold_scaled,
                              X_val_fold_scaled, y_val_fold_scaled, y_scaler_for_unnorm,
                              n_epochs, patience, current_device):
    """Train a BPNN with early stopping and report metrics in original target units.

    The model is optimized with Adam on the MSE of the scaled target. After
    every epoch the validation MAE is computed on un-scaled values; training
    stops once it has not improved for ``patience`` consecutive epochs. The
    weights of the best validation epoch are restored before the final
    evaluation, so the returned state and metrics belong to that epoch.

    Args:
        model_config: Dict with ``hidden_dims``, ``dropout_rate``,
            ``learning_rate``, ``batch_size`` and optionally ``weight_decay``
            (defaults to 0).
        X_train_fold_scaled: 2-D array of scaled training features.
        y_train_fold_scaled: Scaled training targets (reshaped to a column).
        X_val_fold_scaled: 2-D array of scaled validation features.
        y_val_fold_scaled: Scaled validation targets.
        y_scaler_for_unnorm: Fitted scaler whose ``inverse_transform`` maps
            scaled targets back to original units.
        n_epochs: Maximum number of training epochs.
        patience: Number of non-improving epochs tolerated before stopping.
        current_device: Torch device holding the model and the data.

    Returns:
        Dict with ``model_state`` (state_dict of the restored best model),
        ``mae_*``/``rmse_*``/``r2_*`` for ``val`` and ``train``, and the
        per-epoch un-scaled MAE curves ``train_loss_curve_unnorm`` and
        ``val_loss_curve_unnorm``.
    """
    input_dim = X_train_fold_scaled.shape[1]
    model = BPNN(input_dim, model_config['hidden_dims'], dropout_rate=model_config['dropout_rate']).to(current_device)
    optimizer = optim.Adam(model.parameters(), lr=model_config['learning_rate'], weight_decay=model_config.get('weight_decay', 0))
    criterion = nn.MSELoss()

    def _make_loader(features, targets, shuffle):
        # The whole fold is moved to the device once; targets become a column.
        fold_dataset = TensorDataset(
            torch.FloatTensor(features).to(current_device),
            torch.FloatTensor(targets).reshape(-1, 1).to(current_device),
        )
        return DataLoader(fold_dataset, batch_size=model_config['batch_size'], shuffle=shuffle)

    train_loader = _make_loader(X_train_fold_scaled, y_train_fold_scaled, shuffle=True)
    val_loader = _make_loader(X_val_fold_scaled, y_val_fold_scaled, shuffle=False)

    epoch_train_mae_unnormalized = []
    epoch_val_mae_unnormalized = []
    best_val_loss_mae_unnorm = float('inf')
    best_model_state = None
    epochs_no_improve = 0

    for _ in range(n_epochs):
        # --- one training epoch; the train MAE uses the train-mode outputs ---
        model.train()
        pred_chunks, target_chunks = [], []
        for batch_X, batch_y_scaled in train_loader:
            optimizer.zero_grad()
            outputs_scaled = model(batch_X)
            loss = criterion(outputs_scaled, batch_y_scaled)
            loss.backward()
            optimizer.step()
            pred_chunks.append(outputs_scaled.detach().cpu().numpy())
            target_chunks.append(batch_y_scaled.cpu().numpy())
        train_preds_unnorm = y_scaler_for_unnorm.inverse_transform(np.concatenate(pred_chunks, axis=0)).flatten()
        train_targets_unnorm = y_scaler_for_unnorm.inverse_transform(np.concatenate(target_chunks, axis=0)).flatten()
        epoch_train_mae_unnormalized.append(mean_absolute_error(train_targets_unnorm, train_preds_unnorm))

        # --- validation ---
        val_targets_unnorm, val_preds_unnorm = _predict_unscaled(model, val_loader, y_scaler_for_unnorm)
        current_val_mae_unnorm = mean_absolute_error(val_targets_unnorm, val_preds_unnorm)
        epoch_val_mae_unnormalized.append(current_val_mae_unnorm)

        if current_val_mae_unnorm < best_val_loss_mae_unnorm:
            best_val_loss_mae_unnorm = current_val_mae_unnorm
            epochs_no_improve = 0
            # state_dict() returns references to the live tensors, so the
            # snapshot must be cloned or later epochs would overwrite it.
            best_model_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            break

    # Roll back to the best validation epoch before the final evaluation.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    y_val_unnorm, y_pred_val_unnorm = _predict_unscaled(model, val_loader, y_scaler_for_unnorm)
    y_train_unnorm, y_pred_train_unnorm = _predict_unscaled(model, train_loader, y_scaler_for_unnorm)

    return {
        'model_state': model.state_dict(),
        'mae_val': mean_absolute_error(y_val_unnorm, y_pred_val_unnorm),
        'rmse_val': np.sqrt(mean_squared_error(y_val_unnorm, y_pred_val_unnorm)),
        'r2_val': r2_score(y_val_unnorm, y_pred_val_unnorm),
        'mae_train': mean_absolute_error(y_train_unnorm, y_pred_train_unnorm),
        'rmse_train': np.sqrt(mean_squared_error(y_train_unnorm, y_pred_train_unnorm)),
        'r2_train': r2_score(y_train_unnorm, y_pred_train_unnorm),
        'train_loss_curve_unnorm': epoch_train_mae_unnormalized,
        'val_loss_curve_unnorm': epoch_val_mae_unnormalized
    }

# --- 5. 主流程开始 ---
# Load the development set and the held-out final test set; stop the script
# when either Excel file is missing.
try:
    development_df_original = pd.read_excel(DEV_SET_FILE)
    final_test_df_original = pd.read_excel(TEST_SET_FILE)
    print(f"开发集形状: {development_df_original.shape}, 最终测试集形状: {final_test_df_original.shape}")
except FileNotFoundError as e:
    print(f"错误: 开发集或测试集文件未找到。请检查路径: {e}")
    exit()

# Split both sets into features (X) and the target column (y).
X_dev_original_df = development_df_original.drop(columns=[TARGET_COLUMN])
y_dev_original_series = development_df_original[TARGET_COLUMN]
X_final_test_df = final_test_df_original.drop(columns=[TARGET_COLUMN])
y_final_test_series = final_test_df_original[TARGET_COLUMN]

# --- Step 1: BPNN hyperparameter tuning (on the WGAN-GP augmented development set) ---
print("\n--- 步骤一：BPNN 超参数调优 (使用WGAN-GP增强数据) ---")
print("为BPNN HPO生成开发集的增强版本...")
num_augmented_samples_for_hpo = len(development_df_original) * 1 # Generate as many synthetic rows as there are original rows
current_wgan_hpo_params = DEFAULT_WGAN_PARAMS.copy()
current_wgan_hpo_params['epochs'] = DEFAULT_WGAN_PARAMS['epochs_for_final_hpo'] # Train the GAN with the larger epoch budget
augmented_dev_for_hpo_output_path = os.path.join(AUGMENTED_DATA_OUTPUT_FOLDER, "bpnn_augmented_dev_for_hpo.xlsx")

# Augment the whole development set with WGAN-GP
augmented_dev_for_hpo_df = train_and_generate_wgangp(
    input_original_df=development_df_original.copy(), # Pass the full DataFrame, target column included
    target_col_name=TARGET_COLUMN,
    wgan_hyperparams=current_wgan_hpo_params,
    num_samples_to_generate=num_augmented_samples_for_hpo,
    current_device=device,
    fold_num_for_logging="BPNN_HPO_Dev_Set",
    output_augmented_data_path=augmented_dev_for_hpo_output_path
)
# An empty frame is how train_and_generate_wgangp signals failure.
if augmented_dev_for_hpo_df.empty:
    print("错误：为BPNN HPO生成的增强数据为空，无法继续。")
    exit()

# Combine the original and the augmented development data for BPNN HPO
X_aug_dev_hpo_df = augmented_dev_for_hpo_df.drop(columns=[TARGET_COLUMN])
y_aug_dev_hpo_series = augmented_dev_for_hpo_df[TARGET_COLUMN]
X_combined_dev_for_hpo_df = pd.concat([X_dev_original_df, X_aug_dev_hpo_df], ignore_index=True)
y_combined_dev_for_hpo_series = pd.concat([y_dev_original_series, y_aug_dev_hpo_series], ignore_index=True)
print(f"用于BPNN HPO的总数据形状 (X): {X_combined_dev_for_hpo_df.shape}, (y): {y_combined_dev_for_hpo_series.shape}")

# BPNN HPO: random search over BPNN_PARAM_GRID, scored by 3-fold CV validation MAE.
print("开始BPNN超参数搜索 (在WGAN-GP增强数据上)...")
best_bpnn_params = None
best_avg_hpo_val_mae = float('inf')
hpo_results_log = []

# Draw up to N_ITER_BPNN_HPO configurations; duplicates are skipped, so fewer
# than N_ITER_BPNN_HPO distinct configurations may be tried.
sampled_bpnn_configs = []
for _ in range(N_ITER_BPNN_HPO):
    config = {}
    for key, values in BPNN_PARAM_GRID.items():
        config[key] = random.choice(values)
    if config not in sampled_bpnn_configs:
        sampled_bpnn_configs.append(config)

print(f"将尝试 {len(sampled_bpnn_configs)} 组BPNN超参数组合。")

# Convert to NumPy once; the CV folds below index into these arrays.
X_hpo_np = X_combined_dev_for_hpo_df.values
y_hpo_np = y_combined_dev_for_hpo_series.values

for i_config, current_bpnn_hpo_params in enumerate(sampled_bpnn_configs):
    print(f"\n--- BPNN HPO 配置 {i_config+1}/{len(sampled_bpnn_configs)} ---")
    print(current_bpnn_hpo_params)

    # 3-fold CV inside HPO; each configuration gets its own shuffle seed.
    kf_hpo = KFold(n_splits=3, shuffle=True, random_state=RANDOM_STATE + i_config)
    fold_hpo_val_maes = []

    for fold_idx_hpo, (train_idx_hpo, val_idx_hpo) in enumerate(kf_hpo.split(X_hpo_np, y_hpo_np)):
        X_train_hpo_fold, X_val_hpo_fold = X_hpo_np[train_idx_hpo], X_hpo_np[val_idx_hpo]
        y_train_hpo_fold, y_val_hpo_fold = y_hpo_np[train_idx_hpo], y_hpo_np[val_idx_hpo]

        # Feature and target scalers are fitted on the training part of the fold only.
        x_scaler_hpo_fold = StandardScaler().fit(X_train_hpo_fold)
        X_train_hpo_fold_scaled = x_scaler_hpo_fold.transform(X_train_hpo_fold)
        X_val_hpo_fold_scaled = x_scaler_hpo_fold.transform(X_val_hpo_fold)

        y_scaler_hpo_fold = StandardScaler().fit(y_train_hpo_fold.reshape(-1, 1))
        y_train_hpo_fold_scaled = y_scaler_hpo_fold.transform(y_train_hpo_fold.reshape(-1, 1)).flatten()
        y_val_hpo_fold_scaled = y_scaler_hpo_fold.transform(y_val_hpo_fold.reshape(-1, 1)).flatten()

        hpo_fold_results = train_evaluate_bpnn_fold(
            current_bpnn_hpo_params, X_train_hpo_fold_scaled, y_train_hpo_fold_scaled,
            X_val_hpo_fold_scaled, y_val_hpo_fold_scaled, y_scaler_hpo_fold, # y scaler is needed to un-scale the metrics
            n_epochs=EPOCHS_BPNN_HPO, patience=PATIENCE_BPNN_HPO, current_device=device
        )
        fold_hpo_val_maes.append(hpo_fold_results['mae_val'])

    avg_hpo_fold_val_mae = np.mean(fold_hpo_val_maes)
    hpo_results_log.append({'params': current_bpnn_hpo_params, 'avg_val_mae': avg_hpo_fold_val_mae})
    print(f"BPNN HPO 配置 {i_config+1} 平均验证 MAE: {avg_hpo_fold_val_mae:.4f}")

    if avg_hpo_fold_val_mae < best_avg_hpo_val_mae:
        best_avg_hpo_val_mae = avg_hpo_fold_val_mae
        best_bpnn_params = current_bpnn_hpo_params

print("\n--- BPNN 超参数调优完成 ---")
if best_bpnn_params:
    print(f"找到的最佳BPNN超参数: {best_bpnn_params}")
    print(f"最佳BPNN HPO平均验证 MAE: {best_avg_hpo_val_mae:.4f}")
else:
    print("错误：未能找到最佳BPNN参数。将使用第一组尝试的参数。")
    # BPNN_PARAM_GRID is a dict, so it cannot be indexed with 0; when nothing
    # was sampled, fall back to the first listed value of every hyperparameter.
    best_bpnn_params = (
        sampled_bpnn_configs[0]
        if sampled_bpnn_configs
        else {param_name: candidates[0] for param_name, candidates in BPNN_PARAM_GRID.items()}
    )


# --- Step 2: K-fold cross-validation (BPNN + per-fold WGAN-GP augmentation) ---
print(f"\n--- 步骤二：在开发集上进行 {N_SPLITS_KFOLD}-折交叉验证 (BPNN + 动态WGAN-GP增强) ---")
kf_cv = KFold(n_splits=N_SPLITS_KFOLD, shuffle=True, random_state=RANDOM_STATE)

# Per-fold metrics and per-epoch MAE curves collected for the summary and plots.
kfold_cv_val_metrics_list_bpnn = []
kfold_cv_train_metrics_list_bpnn = []
cv_bpnn_train_mae_curves_unnorm = []
cv_bpnn_val_mae_curves_unnorm = []

current_wgan_cv_params = DEFAULT_WGAN_PARAMS.copy()
current_wgan_cv_params['epochs'] = DEFAULT_WGAN_PARAMS['epochs_for_cv'] # Inside CV the GAN is trained with the smaller epoch budget
augmentation_factor_cv = 1 # Each fold generates as many synthetic rows as it has original training rows

for fold_idx_cv, (train_indices_cv, val_indices_cv) in enumerate(kf_cv.split(development_df_original)):
    print(f"\n--- BPNN K-Fold: 第 {fold_idx_cv + 1}/{N_SPLITS_KFOLD} 折 ---")
    cv_train_original_fold_df = development_df_original.iloc[train_indices_cv]
    cv_val_original_fold_df = development_df_original.iloc[val_indices_cv]

    # The validation fold stays purely original data (never augmented).
    X_cv_val_fold_df = cv_val_original_fold_df.drop(columns=[TARGET_COLUMN])
    y_cv_val_fold_series = cv_val_original_fold_df[TARGET_COLUMN]

    print(f"当前CV训练集（原始）形状: {cv_train_original_fold_df.shape}")
    num_augmented_samples_cv_fold = len(cv_train_original_fold_df) * augmentation_factor_cv

    dynamic_wgan_cv_params = current_wgan_cv_params.copy()
    # Cap the GAN batch size at half the fold's training rows (at least 1) so small folds still train
    dynamic_wgan_cv_params['batch_size'] = min(current_wgan_cv_params['batch_size'], max(1, len(cv_train_original_fold_df) // 2 if len(cv_train_original_fold_df) > 1 else 1))

    # Generate WGAN-GP augmented data for this fold's training part only
    cv_augmented_fold_df = train_and_generate_wgangp(
        input_original_df=cv_train_original_fold_df.copy(), # The GAN only sees this fold's original training rows
        target_col_name=TARGET_COLUMN,
        wgan_hyperparams=dynamic_wgan_cv_params,
        num_samples_to_generate=num_augmented_samples_cv_fold,
        current_device=device,
        fold_num_for_logging=(fold_idx_cv + 1),
        output_augmented_data_path=None # Augmented data is not written to disk inside CV
    )
    if cv_augmented_fold_df.empty:
        print(f"警告：Fold {fold_idx_cv + 1} 的WGAN-GP增强数据为空，跳过此折的BPNN训练。")
        # Keep the curve lists aligned with the fold index even for skipped folds.
        cv_bpnn_train_mae_curves_unnorm.append([])
        cv_bpnn_val_mae_curves_unnorm.append([])
        continue

    # Combine this fold's original training rows with the augmented rows
    X_cv_train_original_fold_df = cv_train_original_fold_df.drop(columns=[TARGET_COLUMN])
    y_cv_train_original_fold_series = cv_train_original_fold_df[TARGET_COLUMN]
    X_cv_augmented_fold_df = cv_augmented_fold_df.drop(columns=[TARGET_COLUMN])
    y_cv_augmented_fold_series = cv_augmented_fold_df[TARGET_COLUMN]

    X_cv_train_combined_fold_df = pd.concat([X_cv_train_original_fold_df, X_cv_augmented_fold_df], ignore_index=True)
    y_cv_train_combined_fold_series = pd.concat([y_cv_train_original_fold_series, y_cv_augmented_fold_series], ignore_index=True)
    print(f"当前CV训练集（原始+增强后）形状 (X): {X_cv_train_combined_fold_df.shape}, (y): {y_cv_train_combined_fold_series.shape}")

    # Feature and target scaling, fitted on this fold's combined training data
    x_scaler_cv_fold = StandardScaler().fit(X_cv_train_combined_fold_df.values)
    X_cv_train_combined_fold_scaled = x_scaler_cv_fold.transform(X_cv_train_combined_fold_df.values)
    X_cv_val_fold_scaled = x_scaler_cv_fold.transform(X_cv_val_fold_df.values) # Validation data is transformed with the training scaler

    y_scaler_cv_fold = StandardScaler().fit(y_cv_train_combined_fold_series.values.reshape(-1, 1))
    y_cv_train_combined_fold_scaled = y_scaler_cv_fold.transform(y_cv_train_combined_fold_series.values.reshape(-1, 1)).flatten()
    y_cv_val_fold_scaled = y_scaler_cv_fold.transform(y_cv_val_fold_series.values.reshape(-1, 1)).flatten()

    print(f"Fold {fold_idx_cv + 1}: 开始训练BPNN模型...")
    bpnn_fold_results = train_evaluate_bpnn_fold(
        best_bpnn_params, X_cv_train_combined_fold_scaled, y_cv_train_combined_fold_scaled,
        X_cv_val_fold_scaled, y_cv_val_fold_scaled, y_scaler_cv_fold, # y scaler is needed to un-scale the metrics
        n_epochs=EPOCHS_BPNN_CV_FINAL, patience=PATIENCE_BPNN_CV_FINAL, current_device=device
    )

    kfold_cv_val_metrics_list_bpnn.append({'fold': fold_idx_cv + 1, 'MAE': bpnn_fold_results['mae_val'], 'RMSE': bpnn_fold_results['rmse_val'], 'R2': bpnn_fold_results['r2_val']})
    kfold_cv_train_metrics_list_bpnn.append({'fold': fold_idx_cv + 1, 'MAE': bpnn_fold_results['mae_train'], 'RMSE': bpnn_fold_results['rmse_train'], 'R2': bpnn_fold_results['r2_train']})
    cv_bpnn_train_mae_curves_unnorm.append(bpnn_fold_results['train_loss_curve_unnorm'])
    cv_bpnn_val_mae_curves_unnorm.append(bpnn_fold_results['val_loss_curve_unnorm'])

    print(f"Fold {fold_idx_cv + 1} - BPNN CV Train MAE: {bpnn_fold_results['mae_train']:.4f}, R2: {bpnn_fold_results['r2_train']:.4f} | BPNN CV Val MAE: {bpnn_fold_results['mae_val']:.4f}, R2: {bpnn_fold_results['r2_val']:.4f}")

# Average the per-fold validation metrics; every summary name is defined in
# both branches (NaN when no fold produced results) so later code can rely on it.
avg_kfold_cv_val_metrics_bpnn_df = pd.DataFrame(kfold_cv_val_metrics_list_bpnn)
print("\n--- BPNN K-折交叉验证平均CV验证性能 (开发集, WGAN-GP动态增强) ---")
if not avg_kfold_cv_val_metrics_bpnn_df.empty:
    avg_mae_cv_val_bpnn = avg_kfold_cv_val_metrics_bpnn_df['MAE'].mean()
    avg_rmse_cv_val_bpnn = avg_kfold_cv_val_metrics_bpnn_df['RMSE'].mean()
    avg_r2_cv_val_bpnn = avg_kfold_cv_val_metrics_bpnn_df['R2'].mean()
    print(f"BPNN 平均 CV 验证集 MAE: {avg_mae_cv_val_bpnn:.4f}, RMSE: {avg_rmse_cv_val_bpnn:.4f}, R2: {avg_r2_cv_val_bpnn:.4f}")
else:
    print("BPNN K-Fold CV验证结果为空。")
    # The RMSE average is set here too; it used to stay undefined in this branch.
    avg_mae_cv_val_bpnn, avg_rmse_cv_val_bpnn, avg_r2_cv_val_bpnn = np.nan, np.nan, np.nan

# Same aggregation for the per-fold training metrics (MAE and R2 only).
avg_kfold_cv_train_metrics_bpnn_df = pd.DataFrame(kfold_cv_train_metrics_list_bpnn)
print("\n--- BPNN K-折交叉验证平均CV训练性能 (开发集, WGAN-GP动态增强) ---")
if not avg_kfold_cv_train_metrics_bpnn_df.empty:
    avg_mae_cv_train_bpnn = avg_kfold_cv_train_metrics_bpnn_df['MAE'].mean()
    avg_r2_cv_train_bpnn = avg_kfold_cv_train_metrics_bpnn_df['R2'].mean()
    print(f"BPNN 平均 CV 训练集 MAE: {avg_mae_cv_train_bpnn:.4f}, R2: {avg_r2_cv_train_bpnn:.4f}")
else:
    print("BPNN K-Fold CV训练结果为空。")
    avg_mae_cv_train_bpnn, avg_r2_cv_train_bpnn = np.nan, np.nan

# Plot the BPNN K-Fold CV performance charts (skipped when the averages are NaN)
if not np.isnan(avg_mae_cv_val_bpnn) and not np.isnan(avg_mae_cv_train_bpnn):
    print("\n--- 步骤二结束：生成BPNN K-Fold CV性能图表 ---")
    # Chart 1: grouped bars comparing average train vs. validation MAE and R2.
    metrics_plot_names_en = ['MAE', 'R2 Score']
    values_cv_val_plot_bpnn = [avg_mae_cv_val_bpnn, avg_r2_cv_val_bpnn]
    values_cv_train_plot_bpnn = [avg_mae_cv_train_bpnn, avg_r2_cv_train_bpnn]
    x_axis_plot = np.arange(len(metrics_plot_names_en))
    plt.figure(figsize=(10, 6))
    plt.bar(x_axis_plot - 0.2, values_cv_train_plot_bpnn, width=0.4, label='BPNN CV Train Avg.', align='center')
    plt.bar(x_axis_plot + 0.2, values_cv_val_plot_bpnn, width=0.4, label='BPNN CV Validation Avg.', align='center')
    plt.xticks(x_axis_plot, metrics_plot_names_en)
    plt.ylabel('Score')
    plt.title('BPNN Average K-Fold CV Metrics (WGAN-GP Augmented)')
    plt.legend(); plt.grid(True, linestyle='--', alpha=0.7)
    plot_filename_metrics_bpnn = os.path.join(OUTPUT_PLOT_PATH, "bpnn_wgan_kfold_avg_eval_metrics.png")
    # Saving is best-effort: a failure is reported and the figure is still shown.
    try: plt.savefig(plot_filename_metrics_bpnn, dpi=300, bbox_inches='tight'); print(f"BPNN图表1已保存: {plot_filename_metrics_bpnn}")
    except Exception as e: print(f"保存BPNN图表1错误: {e}")
    plt.show()

    # Chart 2: per-fold train/validation MAE curves; needs at least one non-empty curve of each kind.
    if any(cv_bpnn_train_mae_curves_unnorm) and any(cv_bpnn_val_mae_curves_unnorm):
        plt.figure(figsize=(14, 7))
        max_epochs_cv_bpnn = 0
        for i in range(len(cv_bpnn_train_mae_curves_unnorm)):
            # Skipped folds contributed empty lists and are not drawn.
            if cv_bpnn_train_mae_curves_unnorm[i] and cv_bpnn_val_mae_curves_unnorm[i]:
                epochs_this_fold = len(cv_bpnn_train_mae_curves_unnorm[i])
                max_epochs_cv_bpnn = max(max_epochs_cv_bpnn, epochs_this_fold)
                plt.plot(range(1, epochs_this_fold + 1), cv_bpnn_train_mae_curves_unnorm[i], label=f'BPNN CV Train Fold {i+1}', linestyle='-')
                plt.plot(range(1, epochs_this_fold + 1), cv_bpnn_val_mae_curves_unnorm[i], label=f'BPNN CV Val Fold {i+1}', linestyle='--')
        plt.xlabel('Epoch'); plt.ylabel('Mean Absolute Error (MAE) - Unnormalized')
        plt.title('BPNN Training and CV Validation MAE (Unnormalized) per Fold (WGAN-GP Augmented)')
        if max_epochs_cv_bpnn > 0: plt.xlim(1, max_epochs_cv_bpnn)
        # The legend is only drawn for up to 5 folds to keep the chart readable.
        if len(cv_bpnn_train_mae_curves_unnorm) <= 5: plt.legend(loc='upper right')
        plt.grid(True, linestyle='--', alpha=0.7)
        plot_filename_loss_bpnn = os.path.join(OUTPUT_PLOT_PATH, "bpnn_wgan_kfold_mae_loss_per_fold.png")
        try: plt.savefig(plot_filename_loss_bpnn, dpi=300, bbox_inches='tight'); print(f"BPNN图表2已保存: {plot_filename_loss_bpnn}")
        except Exception as e: print(f"保存BPNN图表2错误: {e}")
        plt.show()
    else: print("没有收集到足够的BPNN MAE曲线数据用于绘制图表2。")
else: print("由于BPNN K-Fold平均结果为空或NaN，跳过图表绘制。")


# --- Step 3: train the final BPNN model ---
print("\n--- 步骤三：训练最终BPNN模型 (使用WGAN-GP增强的完整开发集) ---")
print("为最终BPNN模型生成开发集的完整增强版本...")
final_wgan_params_bpnn = DEFAULT_WGAN_PARAMS.copy()
final_wgan_params_bpnn['epochs'] = DEFAULT_WGAN_PARAMS['epochs_for_final_hpo'] # Same GAN epoch budget as used for HPO
num_augmented_samples_final_bpnn = len(development_df_original) * 2 # The final model gets twice as many synthetic rows as original rows
final_augmented_dev_output_path_bpnn = os.path.join(AUGMENTED_DATA_OUTPUT_FOLDER, "bpnn_augmented_dev_for_final_model.xlsx")

final_augmented_dev_df_bpnn = train_and_generate_wgangp(
    input_original_df=development_df_original.copy(),
    target_col_name=TARGET_COLUMN,
    wgan_hyperparams=final_wgan_params_bpnn,
    num_samples_to_generate=num_augmented_samples_final_bpnn,
    current_device=device,
    fold_num_for_logging="BPNN_Final_Dev_Set_Augmentation",
    output_augmented_data_path=final_augmented_dev_output_path_bpnn
)
# An empty frame is how train_and_generate_wgangp signals failure.
if final_augmented_dev_df_bpnn.empty:
    print("错误：为最终BPNN模型生成的WGAN-GP增强数据为空。")
    exit()

# Combine the original development set with the final augmented data
X_final_aug_dev_df = final_augmented_dev_df_bpnn.drop(columns=[TARGET_COLUMN])
y_final_aug_dev_series = final_augmented_dev_df_bpnn[TARGET_COLUMN]
X_train_final_model_df = pd.concat([X_dev_original_df, X_final_aug_dev_df], ignore_index=True)
y_train_final_model_series = pd.concat([y_dev_original_series, y_final_aug_dev_series], ignore_index=True)
print(f"用于训练最终BPNN模型的总数据形状 (X): {X_train_final_model_df.shape}, (y): {y_train_final_model_series.shape}")

# Fit the scalers for the final model on the full combined training data
x_scaler_final_bpnn = StandardScaler().fit(X_train_final_model_df.values)
X_train_final_model_scaled = x_scaler_final_bpnn.transform(X_train_final_model_df.values)

y_scaler_final_bpnn = StandardScaler().fit(y_train_final_model_series.values.reshape(-1, 1))
y_train_final_model_scaled = y_scaler_final_bpnn.transform(y_train_final_model_series.values.reshape(-1, 1)).flatten()

# The final model also holds out 10% of the data as the early-stopping validation set
if len(X_train_final_model_scaled) > 10:
    X_fm_train_s, X_fm_val_s, y_fm_train_s, y_fm_val_s = train_test_split(
        X_train_final_model_scaled, y_train_final_model_scaled, test_size=0.1, random_state=RANDOM_STATE
    )
else: # Too little data: no separate validation split
    X_fm_train_s, y_fm_train_s = X_train_final_model_scaled, y_train_final_model_scaled
    X_fm_val_s, y_fm_val_s = X_train_final_model_scaled, y_train_final_model_scaled # Nominal validation on the training data itself

print("开始训练最终BPNN模型...")
final_bpnn_model_results = train_evaluate_bpnn_fold(
    best_bpnn_params, X_fm_train_s, y_fm_train_s,
    X_fm_val_s, y_fm_val_s, y_scaler_final_bpnn, # Use the final y scaler
    n_epochs=EPOCHS_BPNN_CV_FINAL, patience=PATIENCE_BPNN_CV_FINAL, current_device=device
)
print("最终BPNN模型训练完成。")

final_bpnn_model_path = os.path.join(MODEL_OUTPUT_PATH, "final_bpnn_wgangp_model.pth")
torch.save(final_bpnn_model_results['model_state'], final_bpnn_model_path) # Save the state_dict returned by train_evaluate_bpnn_fold
# Persist the scalers with joblib; they are needed to predict on the final test set
joblib.dump(x_scaler_final_bpnn, os.path.join(MODEL_OUTPUT_PATH, "final_bpnn_x_scaler.joblib"))
joblib.dump(y_scaler_final_bpnn, os.path.join(MODEL_OUTPUT_PATH, "final_bpnn_y_scaler.joblib"))
print(f"最终BPNN模型参数已保存到: {final_bpnn_model_path}")
print(f"最终BPNN X/Y Scalers 已保存到: {MODEL_OUTPUT_PATH}")

# --- Step 4: final unbiased evaluation (on the "final test set") ---
print("\n--- 步骤四：在最终测试集上进行BPNN无偏评估 ---")
# Rebuild the network with the selected hyper-parameters and restore the
# weights saved in step 3. eval() puts the module into inference mode.
final_bpnn_model_loaded = BPNN(X_final_test_df.shape[1], best_bpnn_params['hidden_dims'], dropout_rate=best_bpnn_params['dropout_rate']).to(device)
final_bpnn_model_loaded.load_state_dict(torch.load(final_bpnn_model_path, map_location=device))
final_bpnn_model_loaded.eval()

# Reload the scalers that were fitted on the training data.
x_scaler_final_loaded = joblib.load(os.path.join(MODEL_OUTPUT_PATH, "final_bpnn_x_scaler.joblib"))
y_scaler_final_loaded = joblib.load(os.path.join(MODEL_OUTPUT_PATH, "final_bpnn_y_scaler.joblib"))

# Scale the test features with the training-time feature scaler.
X_final_test_scaled = x_scaler_final_loaded.transform(X_final_test_df.values)
X_final_test_tensor = torch.FloatTensor(X_final_test_scaled).to(device)

with torch.no_grad():  # no gradients are needed for inference
    y_pred_final_test_scaled_tensor = final_bpnn_model_loaded(X_final_test_tensor)
# inverse_transform requires a 2-D (n_samples, 1) array. Forcing that shape is
# a no-op when the network already returns (N, 1) and avoids a failure if it
# returns a flat (N,) tensor.
y_pred_final_test_scaled_np = y_pred_final_test_scaled_tensor.cpu().numpy().reshape(-1, 1)
y_pred_final_test_unnorm = y_scaler_final_loaded.inverse_transform(y_pred_final_test_scaled_np).flatten()

# Metrics are computed in the original (un-scaled) target units.
mae_final_bpnn = mean_absolute_error(y_final_test_series.values, y_pred_final_test_unnorm)
rmse_final_bpnn = np.sqrt(mean_squared_error(y_final_test_series.values, y_pred_final_test_unnorm))
r2_final_bpnn = r2_score(y_final_test_series.values, y_pred_final_test_unnorm)
print("--- 最终BPNN模型在最终测试集上的性能 ---")
print(f"MAE: {mae_final_bpnn:.4f}, RMSE: {rmse_final_bpnn:.4f}, R2 Score: {r2_final_bpnn:.4f}")

# Scatter plot of actual vs. predicted target values on the final test set.
print("\n--- 步骤四结束：生成BPNN最终测试集真实值 vs 预测值图 ---")
fig_final_bpnn, ax_final_bpnn = plt.subplots(figsize=(8, 8))
ax_final_bpnn.scatter(
    y_final_test_series.values,
    y_pred_final_test_unnorm,
    alpha=0.7,
    edgecolors='w',
    linewidth=0.5,
)

# Dashed identity line spanning the joint range of actual and predicted values;
# points on this line are perfect predictions.
min_val = min(y_final_test_series.min(), y_pred_final_test_unnorm.min())
max_val = max(y_final_test_series.max(), y_pred_final_test_unnorm.max())
ax_final_bpnn.plot([min_val, max_val], [min_val, max_val], 'k--', lw=2)

ax_final_bpnn.set_xlabel('Actual Rowing Distance')
ax_final_bpnn.set_ylabel('Predicted Rowing Distance (BPNN)')
ax_final_bpnn.set_title('BPNN Final Model (WGAN-GP Aug): Actual vs. Predicted (Test Set)')
ax_final_bpnn.grid(True, linestyle='--', alpha=0.7)

plot_filename_actual_vs_pred_bpnn = os.path.join(OUTPUT_PLOT_PATH, "bpnn_wgan_final_actual_vs_predicted.png")
# Saving is best-effort: a failure to write the image is reported but does not
# stop the script from displaying the figure.
try:
    fig_final_bpnn.savefig(plot_filename_actual_vs_pred_bpnn, dpi=300, bbox_inches='tight')
    print(f"BPNN最终测试图已保存: {plot_filename_actual_vs_pred_bpnn}")
except Exception as e:
    print(f"保存BPNN最终测试图错误: {e}")
plt.show()

print("\n--- BPNN + WGAN-GP 整体流程执行完毕 ---")