In [None]:
import pandas as pd
import numpy as np
from scipy.stats import mannwhitneyu
from statsmodels.sandbox.stats.multicomp import multipletests
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches # 用于自定义图例
import seaborn as sns # 虽然这里主要用matplotlib，但可以借鉴其颜色或样式设置
import os

from pathlib import Path 
# Project layout: the project root is taken to be the parent of the current
# working directory, with inputs under "data" and results under "output".
CURRENT_DIR = Path.cwd()
PROJECT_ROOT = CURRENT_DIR.parent
DATA_DIR = PROJECT_ROOT.joinpath("data")
OUTPUT_DIR = PROJECT_ROOT.joinpath("output")

# --- Configuration ---
# Input workbooks (make sure these paths are correct).
base_path = DATA_DIR
original_file = base_path / "development_set_selected_features.xlsx"
# Display name of each augmentation method -> workbook holding its samples.
augmented_files_info = {
    method: base_path / workbook
    for method, workbook in (
        ("Mixup", "augmented_data_mixup_abs_2000.xlsx"),
        ("Noise Injection", "augmented_data_noise_abs_2000.xlsx"),
        ("WGAN-GP", "development_set_selected_features_迭代10000.xlsx"),
    )
}

# Where the comparison figure is written.
output_plot_dir = OUTPUT_DIR
output_plot_file = output_plot_dir / "mwu_p_values_comparison_lancet_style.png"

# Create the output directory (and any missing parents) if it does not exist.
output_plot_dir.mkdir(parents=True, exist_ok=True)

alpha = 0.05  # significance level

# --- Load the original data ---
# A failed load is fatal: the error is printed and the run stops with a
# non-zero exit status. SystemExit is raised directly instead of calling the
# bare `exit()` helper, which is provided by the `site` module for interactive
# sessions and terminated the script with status 0 even on failure.
try:
    df_original = pd.read_excel(original_file)
except FileNotFoundError as e:
    print(f"文件未找到错误: {e}")
    raise SystemExit(1)
except Exception as e:
    # Top-level boundary: any other read/parse problem is reported the same way.
    print(f"加载原始数据时发生错误: {e}")
    raise SystemExit(1)
else:
    print(f"原始数据 '{os.path.basename(original_file)}' 加载成功，形状: {df_original.shape}")

# --- Determine the indicator columns (all columns minus user exclusions) ---
all_original_columns = df_original.columns.tolist()
columns_to_exclude = []  # <--- *** list any original-data columns to leave out of the comparison here ***

indicator_columns = [col for col in all_original_columns if col not in columns_to_exclude]

if not indicator_columns:
    print("错误: 根据 'columns_to_exclude' 设置，没有识别出任何用于比较的指标列。")
    # Nothing to compare: stop with a failing exit status. SystemExit is raised
    # directly because the bare `exit()` helper is meant for interactive use
    # and reported success (status 0).
    raise SystemExit(1)

print(f"\n识别出的用于比较的指标列 ({len(indicator_columns)}个):")
# Uncomment to list every indicator by name:
# for i, col_name in enumerate(indicator_columns):
#     print(f"  {i+1}. {col_name}")

# Sanity check against the expected number of indicators; a mismatch only
# produces a warning and the analysis continues with what was found.
expected_num_indicators = 21
if len(indicator_columns) != expected_num_indicators:
    print(f"\n警告: 识别出的指标列数量为 {len(indicator_columns)}，而不是期望的 {expected_num_indicators} 个。")
    print(f"  代码将继续使用这 {len(indicator_columns)} 个识别出的指标进行分析。")

# --- Load the augmented datasets and verify their indicator columns ---
# Each workbook must load and must contain every indicator column; any failure
# is fatal. Failures raise SystemExit(1) rather than calling the bare `exit()`
# helper, which is intended for interactive sessions and exited with status 0.
dfs_augmented = {}
for name, path in augmented_files_info.items():
    # Keep the try body to the single call that performs the file I/O.
    try:
        df_aug = pd.read_excel(path)
    except FileNotFoundError as e:
        print(f"加载增强数据时文件未找到: {e}")
        raise SystemExit(1)
    except Exception as e:
        print(f"加载增强数据时发生错误: {e}")
        raise SystemExit(1)

    print(f"增强数据 '{os.path.basename(path)}' ({name}) 加载成功，形状: {df_aug.shape}")

    missing_cols_augmented = [col for col in indicator_columns if col not in df_aug.columns]
    if missing_cols_augmented:
        print(f"错误: 增强数据 '{name}' 中缺失以下必要的指标列: {missing_cols_augmented}")
        raise SystemExit(1)

    dfs_augmented[name] = df_aug

# --- Mann-Whitney U tests with Benjamini-Hochberg FDR correction ---
# For each augmentation method, every indicator column of the augmented data is
# compared with the same column of the original data (two-sided test, NaNs
# dropped). The raw p-values of one method are then FDR-corrected together.
#
# results[method] holds four parallel lists, one entry per indicator:
#   'indicators', 'p_values_original', 'p_values_corrected',
#   'significant_after_fdr'
# Indicators that could not be tested keep NaN p-values and are never flagged
# as significant.
results = {}
for aug_name, df_aug in dfs_augmented.items():
    p_values_original = []
    for indicator in indicator_columns:
        original_series = df_original[indicator].dropna()
        augmented_series = df_aug[indicator].dropna()

        p_val = np.nan  # stays NaN when a sample is empty or the test fails
        if not original_series.empty and not augmented_series.empty:
            try:
                _, p_val = mannwhitneyu(original_series, augmented_series, alternative='two-sided')
            except ValueError as ve:
                # Keep going with N/A for this indicator, but say why instead
                # of discarding the reason silently.
                print(f"警告: {aug_name} 的指标 '{indicator}' 无法进行MWU检验，P值记为N/A: {ve}")
        p_values_original.append(p_val)

    # Only finite p-values take part in the correction; their positions are
    # remembered so the corrected values can be written back in place.
    valid_p_indices = [i for i, p in enumerate(p_values_original) if not np.isnan(p)]
    reject_null_all = [False] * len(p_values_original)
    p_values_corrected_all = [np.nan] * len(p_values_original)
    if valid_p_indices:
        reject_null_valid, p_values_corrected_valid, _, _ = multipletests(
            [p_values_original[i] for i in valid_p_indices], alpha=alpha, method='fdr_bh')
        for original_idx, reject, p_corrected in zip(
                valid_p_indices, reject_null_valid, p_values_corrected_valid):
            reject_null_all[original_idx] = bool(reject)
            p_values_corrected_all[original_idx] = p_corrected

    results[aug_name] = {
        'indicators': list(indicator_columns),
        'p_values_original': p_values_original,
        'p_values_corrected': p_values_corrected_all,
        'significant_after_fdr': reject_null_all}

    # Per-indicator report for this method.
    print(f"\n--- {aug_name} vs. Original ---")
    for indicator_name, p_orig, p_corr, is_sig in zip(
            results[aug_name]['indicators'],
            results[aug_name]['p_values_original'],
            results[aug_name]['p_values_corrected'],
            results[aug_name]['significant_after_fdr']):
        if np.isnan(p_orig):
            print(f"指标: {indicator_name:<25} | 原始P值: N/A      | FDR校正P值(q值): N/A      | 是否显著: N/A")
        else:
            print(f"指标: {indicator_name:<25} | 原始P值: {p_orig:.4f} | FDR校正P值(q值): {p_corr:.4f} | 是否显著(q<{alpha}): {is_sig}")

# --- P-value visualisation (Lancet-inspired style) ---
plt.style.use('seaborn-v0_8-whitegrid')  # seaborn whitegrid sheet as the base style
plt.rcParams['font.family'] = 'sans-serif'
# Preferred sans-serif fonts, in priority order. This assignment only stores
# the names; it does not check whether the fonts are installed, so wrapping it
# in try/except (as before, with a bare `except:`) could never report a
# missing font. Per the matplotlib font documentation, font lookup happens at
# draw time and falls through the list.
plt.rcParams['font.sans-serif'] = ['Arial', 'Helvetica', 'DejaVu Sans', 'sans-serif']

# Lancet-inspired palette
color_significant = '#D0021B'       # red: significant difference
color_not_significant = '#0072B2'   # blue: no significant difference (common contrast to the red above)
# color_not_significant = '#00425A' # alternative: a darker teal
color_na = '#B0BEC5'                # light grey: p-value not available
text_color = '#37474F'              # dark grey: text and labels
grid_color = '#DCDCDC'              # light grey: grid lines
line_color_alpha = '#757575'        # mid grey: alpha reference line

# Draw one bar chart per augmentation method (stacked vertically). Bars show the
# raw Mann-Whitney p-value of each indicator and are coloured by whether the
# indicator is significant after FDR correction.
num_augmented_methods = len(dfs_augmented)
# Scale each subplot's height with the number of indicators so the tick labels
# do not crowd each other.
effective_indicator_count = len(indicator_columns) if indicator_columns else 21
fig_height_per_subplot = max(7, effective_indicator_count * 0.4)
# squeeze=False makes plt.subplots always return a 2-D array of Axes, so the
# single-method case needs no special handling. Previously a lone Axes was
# wrapped in a list, which made the later ndarray check fail and silently
# dropped the x-axis label.
fig, axes = plt.subplots(nrows=num_augmented_methods, ncols=1,
                         figsize=(18, fig_height_per_subplot * num_augmented_methods + 1),  # +1 leaves room for titles and the bottom label
                         sharey=False, squeeze=False)
axes = axes.ravel()

for ax, (aug_name, data) in zip(axes, results.items()):
    plot_indicators_names = list(data['indicators'])
    plot_p_values_original = []
    bar_colors_for_plot = []

    for p_orig, is_sig in zip(data['p_values_original'], data['significant_after_fdr']):
        if np.isnan(p_orig):
            # Unavailable p-values are drawn at height 0 in a dedicated colour.
            plot_p_values_original.append(0)
            bar_colors_for_plot.append(color_na)
        else:
            plot_p_values_original.append(p_orig)
            bar_colors_for_plot.append(color_significant if is_sig else color_not_significant)

    if not plot_indicators_names:
        ax.text(0.5, 0.5, "No valid P-values to plot.", horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, fontsize=14, color=text_color)
        ax.set_title(f'{aug_name} vs. Original Data', fontsize=18, pad=20, color=text_color, fontweight='bold')
        continue

    x_positions = np.arange(len(plot_indicators_names))
    ax.bar(x_positions, plot_p_values_original, color=bar_colors_for_plot, width=0.7)

    # Reference line at the nominal significance level.
    ax.axhline(y=alpha, color=line_color_alpha, linestyle='--', linewidth=1.5, label=f'Nominal $\\alpha$ = {alpha}')

    ax.set_title(f'{aug_name} vs. Original Data (P-values, FDR Corrected)', fontsize=18, pad=25, color=text_color, fontweight='bold')
    ax.set_ylabel('Original P-value', fontsize=16, color=text_color, labelpad=15)
    ax.set_ylim(0, 1.1)  # a little headroom above 1 for the legend

    ax.tick_params(axis='y', labelsize=15, colors=text_color)

    # X ticks: one per indicator, labels rotated so long names stay readable.
    ax.set_xticks(x_positions)
    ax.set_xticklabels(plot_indicators_names, rotation=50, ha="right", rotation_mode="anchor", fontsize=16, color=text_color)

    # Horizontal grid lines only, on a white background.
    ax.grid(axis='y', linestyle=':', linewidth=0.8, color=grid_color)
    ax.grid(axis='x', visible=False)
    ax.set_facecolor('white')

    # Hide the top/right spines; tint the remaining ones to match the grid.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_color(grid_color)
    ax.spines['bottom'].set_color(grid_color)

    # Hand-built legend: colour meaning plus the alpha reference line.
    legend_elements = [
        mpatches.Patch(facecolor=color_significant, label=f'Significant (q < {alpha})'),
        mpatches.Patch(facecolor=color_not_significant, label=f'Not Significant (q $\\geq$ {alpha})'),
        plt.Line2D([0], [0], color=line_color_alpha, linestyle='--', lw=1.5, label=f'Nominal $\\alpha$ = {alpha}')
    ]
    if color_na in bar_colors_for_plot:
        legend_elements.append(mpatches.Patch(facecolor=color_na, label='P-value N/A'))

    ax.legend(handles=legend_elements, loc='upper right', fontsize=16, frameon=True, facecolor='white', framealpha=0.9, edgecolor=grid_color)

# Label the x axis once, on the bottom subplot only.
if axes.size > 0:
    axes[-1].set_xlabel('Indicators', fontsize=18, color=text_color, labelpad=20)


plt.tight_layout(pad=3.0, h_pad=4.0, w_pad=3.0)
# Applied after tight_layout, so this overrides the top margin it chose and
# lowers the top of the subplot area slightly per method.
fig.subplots_adjust(top=0.95 - (0.01 * num_augmented_methods))

# Saving is best-effort: a failure is reported but the figure is still shown.
try:
    plt.savefig(output_plot_file, dpi=300, bbox_inches='tight')
    print(f"\n图表已保存到: {output_plot_file}")
except Exception as e:
    print(f"保存图表时发生错误: {e}")

plt.show()