# In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

from pathlib import Path 
# Project folders are resolved relative to the current working directory:
# the project root is the parent of the cwd, with `data/` and `output/` below it.
# NOTE(review): this assumes the script is launched from a direct sub-directory
# of the project root — confirm if the launch location changes.
CURRENT_DIR = Path.cwd()
PROJECT_ROOT = CURRENT_DIR.parent
DATA_DIR, OUTPUT_DIR = PROJECT_ROOT / "data", PROJECT_ROOT / "output"



# --- User configuration ---
# Input and output folders.
data_folder = DATA_DIR
output_folder = OUTPUT_DIR

# Dataset label -> Excel file. Paths are built from `data_folder`, so changing
# the folder above redirects every input file. Path(...) also accepts a plain
# string folder. The first entry is the reference dataset whose columns are plotted.
datasets_info = {
    "Original": Path(data_folder, "development_set_selected_features.xlsx"),
    "Mixup": Path(data_folder, "augmented_data_mixup_abs_2000.xlsx"),
    "NoiseInjection": Path(data_folder, "augmented_data_noise_abs_2000.xlsx"),
    # "迭代10000" in the file name means "10000 iterations".
    "WGAN-GP": Path(data_folder, "development_set_selected_features_迭代10000.xlsx"),
}

# Output figure settings.
output_filename = "all_datasets_kde_comparison.png"
dpi = 300
num_plot_cols = 4  # subplots per row; adjust to the number of features
# --- End of user configuration ---

# Create the destination folder for the saved figure on first run; the
# message is printed only when the folder actually had to be created.
output_folder_missing = not os.path.exists(output_folder)
if output_folder_missing:
    os.makedirs(output_folder)  # also creates any missing parent folders
    print(f"Created output directory: {output_folder}")

# --- Load every dataset listed in `datasets_info` ---
data_frames = {}
# The first configured dataset (normally the original data) defines which
# feature columns are plotted. `next(..., None)` avoids an IndexError when
# `datasets_info` is empty; that case is reported by the check inside `try`.
first_dataset_name = next(iter(datasets_info), None)

# Any loading failure aborts the script with a non-zero exit status, because
# nothing below can run without the data. `raise SystemExit(1)` replaces the
# `exit()` helper. That helper exits with status 0, is only provided by the
# `site` module, and under IPython/Jupyter is a kernel-shutdown helper instead.
try:
    for name, path in datasets_info.items():
        print(f"Loading dataset: {name} from {path}")
        data_frames[name] = pd.read_excel(path)
        print(f"Successfully loaded {name}. Shape: {data_frames[name].shape}")

    if not data_frames:
        raise ValueError("No datasets were loaded.")

    # Use the columns of the first dataset as the reference feature list.
    reference_df = data_frames[first_dataset_name]
    data_columns = reference_df.columns
    num_features = len(data_columns)
    print(f"Using columns from '{first_dataset_name}': {data_columns.tolist()}")

except FileNotFoundError as e:
    print(f"Error: File not found. Please check your file paths. Details: {e}")
    raise SystemExit(1)
except ValueError as e:
    print(f"Error: {e}")
    raise SystemExit(1)
except Exception as e:
    print(f"An unexpected error occurred during data loading: {e}")
    raise SystemExit(1)


# Lay the features out on a grid with `num_plot_cols` columns. The row count
# is the ceiling of num_features / num_plot_cols.
num_plot_rows = -(-num_features // num_plot_cols)

# squeeze=False always yields a 2-D array of Axes. Flatten it so each subplot
# can be addressed by a single feature index.
fig, axes = plt.subplots(
    num_plot_rows,
    num_plot_cols,
    figsize=(6 * num_plot_cols, 5 * num_plot_rows),
    squeeze=False,
)
axes = axes.flatten()

# One distinct "husl" colour per dataset, assigned in configuration order.
colors = sns.color_palette("husl", n_colors=len(datasets_info))
dataset_colors = dict(zip(datasets_info, colors))



print(f"\nStarting plot generation for {num_features} features...")

# Font sizes for subplot titles, axis labels and the figure-level legend.
TITLE_SIZE = 20
LABEL_SIZE = 14
LEGEND_SIZE = 20

# Shared appearance of every KDE curve: a semi-transparent filled area.
KDE_STYLE = {"fill": True, "alpha": 0.2, "linewidth": 1.5}

# One subplot per feature, overlaying that feature's density in each dataset.
for i, col in enumerate(data_columns):
    ax = axes[i]
    ax.set_title(col, fontsize=TITLE_SIZE)

    # Warn once per feature about datasets lacking this column; they are skipped below.
    missing_in = [name for name, df in data_frames.items() if col not in df.columns]
    if missing_in:
        print(f"Warning: Column '{col}' missing in datasets: {', '.join(missing_in)}. Skipping this column for them.")

    for name, df in data_frames.items():
        if col not in df.columns:
            continue
        # Best effort: an error while plotting one dataset is reported and
        # that dataset is skipped, instead of aborting the whole figure.
        try:
            sns.kdeplot(
                df[col].dropna(),
                ax=ax,
                label=name,
                color=dataset_colors[name],
                **KDE_STYLE,
            )
        except Exception as e:
            print(f"Could not plot KDE for {name} - {col}. Error: {e}. Skipping.")

    ax.set_xlabel("Value", fontsize=LABEL_SIZE)
    ax.set_ylabel("Density", fontsize=LABEL_SIZE)

# ---------- Figure-level legend ----------
# Gather legend entries from every feature subplot, not just the first one,
# de-duplicated by label (first occurrence wins). A dataset that could not be
# drawn on the first subplot therefore still gets a legend entry.
legend_entries = {}
for ax in axes[:num_features]:
    for handle, label in zip(*ax.get_legend_handles_labels()):
        legend_entries.setdefault(label, handle)
handles = list(legend_entries.values())
labels = list(legend_entries)

if handles:
    # Anchor the legend in the bottom-right corner of the whole figure.
    # NOTE: when the subplot grid has no empty cell, the legend overlaps the
    # last subplot.
    fig.legend(
        handles,
        labels,
        loc="lower right",
        bbox_to_anchor=(0.98, 0.02),  # fine-tune position in figure coordinates
        fontsize=LEGEND_SIZE,
        frameon=False,
    )
else:
    # Nothing was drawn (e.g. no features); calling fig.legend(None, None)
    # would raise, so skip the legend instead.
    print("Warning: no KDE curves were drawn; the figure legend is omitted.")

# ---------- Remove unused grid cells ----------
# The grid can hold more axes than there are features. Drop the surplus ones,
# which leaves empty space in the bottom-right corner when the last row is
# not full.
for unused_ax in axes[num_features:]:
    fig.delaxes(unused_ax)

plt.tight_layout(pad=2.0)  # extra padding between subplots

# ---------- Save the figure ----------
output_file_path = os.path.join(output_folder, output_filename)
try:
    fig.savefig(output_file_path, dpi=dpi, format='png', bbox_inches='tight')
except Exception as e:
    # Best effort: report a failed save instead of aborting the script.
    print(f"Error saving plot: {e}")
else:
    print(f"\nSuccessfully saved comparison plot to: {output_file_path}")

# Uncomment to also display the figure after the script runs.
# plt.show()

print("Script finished.")