# In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt # 暂时保留，虽然此代码不主动绘图，但调试时可能有用
import os
from pathlib import Path 
# Project layout: assumes this notebook runs one directory below the project
# root, and that the root holds the `data/` and `output/` folders.
CURRENT_DIR = Path.cwd()
PROJECT_ROOT = CURRENT_DIR.parent
DATA_DIR = PROJECT_ROOT.joinpath("data")
OUTPUT_DIR = PROJECT_ROOT.joinpath("output")

# --- 1. Configuration ---
# Input spreadsheet; every column in it (inputs and target alike) is modeled jointly.
data_file_path = DATA_DIR.joinpath("development_set_selected_features.xlsx")
# The augmented data is written back into DATA_DIR under its own name.
output_folder_path = DATA_DIR
output_file_name = "development_set_selected_features_迭代10000.xlsx"
full_output_path = os.path.join(output_folder_path, output_file_name)  # a str, not a Path


# WGAN-GP hyperparameters
latent_dim = 100   # length of the noise vector fed to the generator
lambda_gp = 10     # weight of the gradient-penalty term in the critic loss
n_critic = 5       # critic updates per generator update
epochs = 10000     # GANs usually need many epochs; tune to the dataset
batch_size = 32    # mini-batch size (adjust to data size and memory)
lr = 5e-05         # learning rate for both optimizers

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"将使用设备: {device}")

# --- 2. Data loading and preprocessing ---
# Load the spreadsheet. On failure, stop with a non-zero exit status: a bare
# `exit()` reports success (status 0) and only exists when the `site` module
# injected it. Chaining keeps the original error attached.
try:
    original_data_df = pd.read_excel(data_file_path)
    print(f"原始数据已从 '{data_file_path}' 加载，形状: {original_data_df.shape}")
except FileNotFoundError as err:
    print(f"错误: 文件 '{data_file_path}' 未找到。请检查路径。")
    raise SystemExit(1) from err
except Exception as e:
    print(f"加载数据时发生错误: {e}")
    raise SystemExit(1) from e

# Every column (inputs and target) is treated as a continuous feature.
all_feature_names = original_data_df.columns.tolist()
print(f"所有特征 ({len(all_feature_names)}): {all_feature_names}")

# float32 array; torch.tensor below inherits this dtype.
original_data_values = original_data_df.to_numpy(dtype=np.float32)

# An empty sheet would give NaN mean/std and no batches to train on.
if original_data_values.shape[0] == 0:
    print("错误: 数据文件不包含任何样本。")
    raise SystemExit(1)

# NaN/inf would propagate through mean/std and turn every loss into NaN, so
# reject them here rather than training on garbage.
finite_mask = np.isfinite(original_data_values)
if not finite_mask.all():
    bad_columns = original_data_df.columns[~finite_mask.all(axis=0)].tolist()
    print(f"错误: 以下列包含缺失值或无穷值，请先处理: {bad_columns}")
    raise SystemExit(1)

# (A) Z-score every feature (np.std uses the population std, ddof=0).
data_mean = np.mean(original_data_values, axis=0)
data_std = np.std(original_data_values, axis=0)
# Constant columns have std 0: use 1 so they standardize to 0 instead of NaN.
data_std[data_std == 0] = 1
standardized_data = (original_data_values - data_mean) / data_std

# (B) Build the DataLoader.
data_tensor = torch.tensor(standardized_data)
dataset = TensorDataset(data_tensor)
# Drop the ragged final batch only when at least one full batch exists.
# With fewer rows than `batch_size`, drop_last=True would yield zero batches,
# so the single partial batch is kept instead.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=len(dataset) >= batch_size,
)

# --- 3. GAN models (Generator and Critic) ---
class Generator(nn.Module):
    """Fully connected generator mapping latent noise to synthetic samples.

    Hidden widths are 128 -> 256 -> 512, each followed by ReLU. The output
    layer is linear with no activation, so it can emit any real value, which
    suits the standardized (unbounded) features it is trained on.

    Args:
        input_dim: Length of the latent noise vector.
        output_dim: Number of features per generated sample.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        widths = (input_dim, 128, 256, 512)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Linear head: continuous outputs, no squashing activation.
        layers.append(nn.Linear(widths[-1], output_dim))
        # The attribute name `model` and the layer order determine state_dict keys.
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map latent vectors (last dimension == input_dim) to samples."""
        return self.model(z)

class Critic(nn.Module):
    """Fully connected WGAN critic that scores each sample with one real number.

    Hidden widths are 512 -> 256 -> 128, each followed by LeakyReLU(0.2). The
    head is a plain Linear layer with no sigmoid: the score is unbounded, as a
    Wasserstein critic requires.

    Args:
        input_dim: Number of features per sample.
    """

    def __init__(self, input_dim):
        super().__init__()
        widths = (input_dim, 512, 256, 128)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            # LeakyReLU usually works better than ReLU in the critic/discriminator.
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Linear(widths[-1], 1))  # one scalar score per sample
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return critic scores of shape (..., 1) for inputs whose last dim is input_dim."""
        return self.model(x)

# --- 4. Gradient penalty ---
def gradient_penalty(critic_model, real_samples, fake_samples, device_in_use=None):
    """Compute the WGAN-GP gradient penalty.

    Evaluates the critic at random points on the segments between paired real
    and fake samples. The penalty is the batch mean of (||grad_x critic||_2 - 1)**2.

    Args:
        critic_model: Callable mapping a batch of samples to critic scores.
        real_samples: Real batch with shape (batch, ...); any number of
            feature dimensions is supported.
        fake_samples: Generated batch with the same shape as ``real_samples``.
            It is detached here, so the penalty never back-propagates into
            the generator.
        device_in_use: Device for the interpolation coefficients. Defaults to
            the device of ``real_samples``.

    Returns:
        Scalar tensor. It is built with ``create_graph=True`` so that
        back-propagating it updates the critic's parameters.
    """
    if device_in_use is None:
        device_in_use = real_samples.device
    batch_size_gp = real_samples.size(0)

    # One alpha ~ U[0, 1) per sample, shaped (batch, 1, ..., 1) so it
    # broadcasts over any number of feature dimensions.
    alpha_shape = (batch_size_gp,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device_in_use, dtype=real_samples.dtype)

    # Detach both endpoints: the penalty regularizes only the critic, and a
    # fresh leaf tensor lets us differentiate w.r.t. the interpolates directly.
    interpolates = (
        alpha * real_samples.detach() + (1 - alpha) * fake_samples.detach()
    ).requires_grad_(True)
    d_interpolates = critic_model(interpolates)

    # grad_outputs of ones differentiates sum(d_interpolates). That equals each
    # sample's own input gradient, assuming the critic scores samples
    # independently (no cross-batch layers such as batch-norm).
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,  # the penalty itself must be differentiable
        retain_graph=True,
    )[0]
    # Flatten feature dims per sample; reshape also handles non-contiguous grads.
    gradients = gradients.reshape(batch_size_gp, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# --- 5. Models and optimizers ---
# One input/output unit per spreadsheet column (the original notes expected 21;
# the actual value depends on the input file).
num_features = standardized_data.shape[-1]

generator = Generator(latent_dim, num_features).to(device)
critic = Critic(num_features).to(device)

# Both networks use Adam with betas=(0.5, 0.9) (the original notes it tends to
# be more stable than RMSprop here).
optimizer_G, optimizer_C = (
    optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.9)) for net in (generator, critic)
)

# --- 6. Training loop ---
# With zero batches the inner loop never runs, and the loss report below would
# raise NameError on `critic_loss`. Fail early with a clear message instead.
if len(dataloader) == 0:
    raise ValueError(
        f"DataLoader 未产生任何批次 (样本数: {len(dataloader.dataset)}, batch_size: {batch_size})。"
    )

print("\n开始WGAN-GP训练...")
# Set training mode explicitly: a notebook re-run after the generation step
# (which calls generator.eval()) would otherwise start in eval mode.
generator.train()
critic.train()
for epoch in range(epochs):
    for (real_samples_batch,) in dataloader:
        real_samples_batch = real_samples_batch.to(device)
        current_batch_size = real_samples_batch.size(0)

        # ---------------------
        #  Train the critic
        # ---------------------
        for _ in range(n_critic):
            optimizer_C.zero_grad()

            # Only the critic is updated in this phase, so run the generator
            # without autograd. This avoids building its graph n_critic times
            # per batch and stops the gradient penalty from back-propagating
            # into (and accumulating grads on) the generator's weights.
            z = torch.randn(current_batch_size, latent_dim, device=device)
            with torch.no_grad():
                fake_samples_batch = generator(z)

            # Critic scores for real and generated samples.
            critic_real = critic(real_samples_batch)
            critic_fake = critic(fake_samples_batch)

            gp = gradient_penalty(critic, real_samples_batch, fake_samples_batch, device)

            # Critic loss: negated Wasserstein estimate plus the weighted penalty.
            critic_loss = torch.mean(critic_fake) - torch.mean(critic_real) + lambda_gp * gp

            critic_loss.backward()
            optimizer_C.step()

        # -----------------
        #  Train the generator
        # -----------------
        optimizer_G.zero_grad()

        # A fresh batch of fakes, this time with autograd enabled.
        z = torch.randn(current_batch_size, latent_dim, device=device)
        generated_for_g_loss = generator(z)

        # Generator loss: push the critic's score on fakes up. This backward
        # also writes gradients into the critic's parameters; they are cleared
        # by optimizer_C.zero_grad() before the next critic step.
        generator_loss = -torch.mean(critic(generated_for_g_loss))

        generator_loss.backward()
        optimizer_G.step()

    # Report the last batch's losses every 200 epochs.
    if (epoch + 1) % 200 == 0:
        print(
            f"[Epoch {epoch+1}/{epochs}] "
            f"[Critic Loss: {critic_loss.item():.4f}] "
            f"[Generator Loss: {generator_loss.item():.4f}]"
        )

print("训练完成。")

# --- 7. Generate new data and post-process it ---
print("\n正在生成增强数据...")
# Number of synthetic rows to draw. Use original_data_df.shape[0] instead to
# match the size of the original data.
num_generated_samples = 2000

generator.eval()  # switch to evaluation mode for sampling
with torch.no_grad():
    z_generate = torch.randn(num_generated_samples, latent_dim, device=device)
    # Already outside autograd under no_grad, so no .detach() is needed.
    generated_standardized_data_np = generator(z_generate).cpu().numpy()

# (A) Undo the standardization to return to the original units.
generated_data_original_scale_np = generated_standardized_data_np * data_std + data_mean
generated_data_df = pd.DataFrame(generated_data_original_scale_np, columns=all_feature_names)

# NOTE: an optional step that snapped 'CHO' to the nearest value in a
# CHO_LEVELS array used to be sketched here. It is disabled, and CHO_LEVELS is
# not defined in this file.

# (C) Clip every feature to its observed range, widened by 1% of that range
# on each side. For constant columns the margin is 0, so they clip to their
# single observed value.
print("正在对其他连续特征进行范围裁剪...")
observed_min = original_data_df[all_feature_names].min()
observed_max = original_data_df[all_feature_names].max()
margin = 0.01 * (observed_max - observed_min)
clip_lower = observed_min - margin
clip_upper = observed_max + margin
# Nutrient columns that cannot physically be negative never go below 0
# (extend this tuple with other such column names if needed).
nonnegative_cols = [c for c in ("CHO", "PRO") if c in clip_lower.index]
clip_lower.loc[nonnegative_cols] = clip_lower.loc[nonnegative_cols].clip(lower=0)
# axis=1 aligns the per-column bound Series with the DataFrame's columns.
generated_data_df = generated_data_df.clip(lower=clip_lower, upper=clip_upper, axis=1)

print("所有其他连续特征裁剪完成。")


# --- 8. Save the augmented data ---
# Best-effort save: any failure (bad path, missing Excel writer, ...) is
# reported and does not raise.
try:
    # Create the destination folder if needed; a no-op when it already exists.
    os.makedirs(output_folder_path, exist_ok=True)
    generated_data_df.to_excel(full_output_path, index=False)
    summary_lines = (
        f"\n增强数据已成功保存到: {full_output_path}",
        f"生成数据形状: {generated_data_df.shape}",
        "生成数据前5行:",
    )
    for line in summary_lines:
        print(line)
    print(generated_data_df.head())
except Exception as save_error:
    print(f"保存增强数据时发生错误: {save_error}")