In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import os


from pathlib import Path 
# --- Project layout ---
# Everything is resolved from the current working directory: the project root
# is taken to be its parent, with "data" and "output" directly beneath it.
CURRENT_DIR = Path.cwd()
PROJECT_ROOT = CURRENT_DIR.parent
DATA_DIR = PROJECT_ROOT.joinpath("data")
OUTPUT_DIR = PROJECT_ROOT.joinpath("output")



# --- Configuration Parameters ---
# Input locations: every workbook is expected directly inside DATA_DIR.
base_path = DATA_DIR
original_file = base_path / "development_set_selected_features.xlsx"

# Display label of each augmentation method -> workbook holding its samples.
# Insertion order is preserved and determines the order of the plotted panels.
_augmented_file_names = {
    "Mixup": "augmented_data_mixup_abs_2000.xlsx",
    "Noise Injection": "augmented_data_noise_abs_2000.xlsx",
    "WGAN-GP": "development_set_selected_features_迭代10000.xlsx",
}
augmented_files_info = {
    label: DATA_DIR / file_name
    for label, file_name in _augmented_file_names.items()
}

# Output location of the comparison figure.
output_plot_dir = OUTPUT_DIR
output_plot_file = output_plot_dir / "correlation_matrix_comparison.png"

# Make sure the output directory exists before anything is written to it.
output_plot_dir.mkdir(parents=True, exist_ok=True)

# --- Load Original Data ---
# The rest of the script cannot run without the original data set, so any
# failure aborts the run. SystemExit with a message writes it to stderr and
# yields a non-zero exit status; the interactive `exit()` helper would report
# success (status 0) and is only available when the `site` module is loaded.
try:
    df_original = pd.read_excel(original_file)
except FileNotFoundError:
    raise SystemExit(f"ERROR: Original data file not found at {original_file}") from None
except Exception as e:
    # Top-level boundary: report any reader/parsing failure and stop.
    raise SystemExit(f"Error loading original data: {e}") from e
else:
    print(f"Original data '{os.path.basename(original_file)}' loaded successfully. Shape: {df_original.shape}")

# --- Determine Indicator Columns (same logic as your MWU script) ---
all_original_columns = df_original.columns.tolist()

# !!! USER ACTION: list every column of the original data file that is NOT an indicator !!!
# Example: columns_to_exclude = ['Subject_ID', 'Timepoint']
columns_to_exclude = []  # Keep empty if every column of the original data file is an indicator

# Preserve the original column order so all correlation matrices share one layout.
indicator_columns = [col for col in all_original_columns if col not in columns_to_exclude]

if not indicator_columns:
    # Name the file that was actually loaded, and stop with a non-zero exit
    # status so callers/schedulers can detect the failure.
    raise SystemExit(
        "ERROR: No indicator columns identified. Please check "
        f"'{os.path.basename(original_file)}' and 'columns_to_exclude'."
    )

print(f"\nUsing {len(indicator_columns)} indicator columns for correlation analysis:")
# for col_name in indicator_columns: print(f"  - {col_name}") # Uncomment to list them

# Optional: Check if the number of indicators is 21, as per your previous context
# expected_num_indicators = 21
# if len(indicator_columns) != expected_num_indicators:
#     print(f"\nWARNING: Identified {len(indicator_columns)} indicators, but expected {expected_num_indicators}.")
#     print("Proceeding with the identified columns.")

# --- Load Augmented Data ---
# Maps each augmentation label to a frame restricted to the indicator columns.
# Any failure aborts the run with a non-zero exit status (SystemExit with a
# message) instead of the interactive `exit()` helper, which exits with 0.
dfs_augmented = {}
for name, path in augmented_files_info.items():
    # Keep the try body limited to the read itself so that only I/O and
    # parsing problems are reported as loading errors.
    try:
        df_aug = pd.read_excel(path)
    except FileNotFoundError:
        raise SystemExit(f"ERROR: Augmented data file not found at {path}") from None
    except Exception as e:
        raise SystemExit(f"Error loading augmented data {name}: {e}") from e

    print(f"Augmented data '{os.path.basename(path)}' ({name}) loaded successfully. Shape: {df_aug.shape}")

    # Every indicator column of the original data must exist in the augmented
    # frame; otherwise the correlation matrices would not be comparable.
    missing_cols = [col for col in indicator_columns if col not in df_aug.columns]
    if missing_cols:
        raise SystemExit(
            f"ERROR: Augmented data '{name}' is missing the following indicator columns: {missing_cols}"
        )

    dfs_augmented[name] = df_aug[indicator_columns]  # Keep only indicator columns

# Ordered (panel title, frame) pairs: the original data first, followed by one
# entry per augmentation method in the order they were loaded.
datasets_to_plot = [
    ("Original Data", df_original[indicator_columns]),
    *((f"{label} Augmented", frame) for label, frame in dfs_augmented.items()),
]

# --- 1. Shared colour scale and colormap for every panel ---
# Correlations live in [-1, 1]; a fixed range keeps the panels comparable.
vmin, vmax = -1, 1
cmap = sns.diverging_palette(250, 15, s=80, l=55, n=9, center="light", as_cmap=True)

# --- 2. Build the figure: two columns, as many rows as the datasets need ---
# Sizing the grid from the data means no dataset is silently dropped when more
# than four are configured; four datasets still give a 2x2 grid at 20x16 in.
n_panels = len(datasets_to_plot)
n_cols = 2
n_rows = max(1, (n_panels + n_cols - 1) // n_cols)  # ceiling division
fig, axes = plt.subplots(
    n_rows, n_cols,
    figsize=(10 * n_cols, 8 * n_rows),
    squeeze=False,  # always a 2-D array, so flatten() works for any grid
)
axes = axes.flatten()

# --- 3. Draw the heatmaps, each without its own colorbar ---
tick_font = 12
for ax, (title, df_data) in zip(axes, datasets_to_plot):
    corr = df_data.corr()
    # Hide the upper triangle (diagonal included): it mirrors the lower one.
    mask = np.triu(np.ones_like(corr, dtype=bool))
    sns.heatmap(
        corr,
        ax=ax,
        mask=mask,
        cmap=cmap,
        vmin=vmin, vmax=vmax, center=0,
        square=True,
        linewidths=.3,
        cbar=False,  # a single shared colorbar is added below
        annot=False,
    )
    ax.set_title(title, fontsize=18, pad=15, fontweight='bold')

    # Larger tick labels for readability.
    ax.tick_params(axis='x', rotation=90, labelsize=tick_font)
    ax.tick_params(axis='y', rotation=0, labelsize=tick_font)

# --- 4. Remove unused grid cells, add the title, then lay out the panels ---
for unused_ax in axes[n_panels:]:
    fig.delaxes(unused_ax)

fig.suptitle('Comparison of Feature Correlation Matrices',
             fontsize=22, fontweight='bold', y=0.98)
# Run tight_layout before the colorbar axes exist: manually placed axes are
# not supported by tight_layout and trigger a warning. The rect keeps the
# right 10% free for the colorbar and the top 5% free for the suptitle.
fig.tight_layout(rect=[0, 0, 0.9, 0.95])

# --- 5. Add one global colorbar in the reserved strip on the right ---
cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])   # [left, bottom, width, height]
cbar = fig.colorbar(
    plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin, vmax)),
    cax=cbar_ax,
    orientation='vertical',
)
cbar.ax.tick_params(labelsize=12)   # colorbar tick label size
cbar.set_label('Correlation Coefficient', fontsize=14)

# --- 6. Save and display ---
plt.savefig(output_plot_file, dpi=300, bbox_inches='tight')
plt.show()