# In [1]:
# -*- coding: utf-8 -*-
import os
import numpy as np
import pandas as pd
import shap
from sklearn.model_selection import train_test_split
from catboost import CatBoostClassifier, Pool
import matplotlib
matplotlib.use('Agg')  # non-GUI backend, suitable for servers or for saving images to disk
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

# Font configuration: use Times New Roman for sans-serif text and disable the
# Unicode minus glyph on axes.
plt.rcParams.update({
    'font.sans-serif': ['Times New Roman'],
    'axes.unicode_minus': False,
})

# Output directory for every artifact written by this script; created up
# front (no error if it already exists).
save_dir = 'M2_compare/SHAP/'
os.makedirs(save_dir, exist_ok=True)

# SHAP plotting helper
def shap_summary_plot(shap_values):
    """Draw a SHAP beeswarm plot and save it as a PNG under ``save_dir``.

    Args:
        shap_values: Object handed straight to ``shap.plots.beeswarm``
            (the caller in this file passes the result of calling a
            ``shap.Explainer``).

    The current pyplot figure is cleared and reused. The image is written to
    ``shap_summary_plot.png`` inside ``save_dir`` at 300 dpi. Nothing is
    returned.
    """
    output_path = os.path.join(save_dir, 'shap_summary_plot.png')
    colormap = plt.get_cmap("cool")

    plt.clf()  # start from an empty current figure
    shap.plots.beeswarm(
        shap_values,
        max_display=27,
        show=False,  # do not display; the figure is resized and saved below
        color=colormap,
    )

    # Set the figure size after plotting, then decorate and save.
    plt.gcf().set_size_inches(6, 10)
    plt.title('SHAP Feature Importance')
    plt.grid(False)
    plt.savefig(output_path, dpi=300, bbox_inches='tight')

# Model training, SHAP analysis and result export
def model_shap_train_vali_test_category_balance(X, y):
    """Fit a CatBoost classifier and export SHAP results for ``X``.

    Steps performed:
      1. Split ``X``/``y`` 80/20 with ``random_state=3``; only the training
         part is used (the held-out part is discarded).
      2. Fit a ``CatBoostClassifier`` (500 iterations, learning rate 0.01).
      3. Compute SHAP values for *all* rows of ``X`` (not just the training
         rows).
      4. Write ``shap_values.csv``, ``shap_data.csv`` and
         ``shap_importance.csv`` (mean absolute SHAP value per feature,
         sorted descending) into ``save_dir``.
      5. Save the beeswarm plot via ``shap_summary_plot``.

    Args:
        X: Feature table; its ``columns`` are used as CSV headers.
        y: Target values aligned with ``X``.

    Returns:
        None. All results are written to disk.

    NOTE(review): despite the function name, no validation split and no
    class balancing / stratification is performed in this function.
    """
    # Only the training portion is needed; the test portion is unused.
    X_train, _, y_train, _ = train_test_split(X, y, test_size=0.2, random_state=3)

    model = CatBoostClassifier(iterations=500, learning_rate=0.01, verbose=0)
    train_pool = Pool(X_train, y_train)
    model.fit(train_pool)

    # Explain the full feature table with the fitted model.
    shap_values = shap.Explainer(model)(X)
    feature_names = X.columns

    def export(frame, filename):
        """Write ``frame`` to ``save_dir/filename`` without the index column."""
        frame.to_csv(os.path.join(save_dir, filename), index=False)

    # Raw SHAP values and the corresponding input data, row-aligned.
    export(pd.DataFrame(shap_values.values, columns=feature_names), 'shap_values.csv')
    export(pd.DataFrame(shap_values.data, columns=feature_names), 'shap_data.csv')

    # Feature importance: mean |SHAP| per feature, largest first.
    importance_table = pd.DataFrame({
        'feature name': feature_names,
        'importance': np.abs(shap_values.values).mean(axis=0),
    }).sort_values(by='importance', ascending=False)
    export(importance_table, 'shap_importance.csv')

    # Plot
    shap_summary_plot(shap_values)

# Load the dataset and shuffle its rows with a fixed seed
rawdata = pd.read_excel('dataset/OHCA_M2.xlsx').sample(frac=1, random_state=42)

# Target column
datay = rawdata['Patient outcome (M2)']

# Feature columns passed to the model
x_features = [
    'age',
    'Bystander use of AEDs',
    'Time to ambulance arrival',
    'Performer of defibrillation_Medical staff',
    'Bystander CPR',
    'Initial rhythm of cardiac arrest_Normal heart rhythm',
    'Initial rhythm of cardiac arrest_Shockable heart rhythm',
    'Initial rhythm of cardiac arrest_Non-shockable heart rhythm',
    'Out-of-hospital electrical defibrillation',
    'Location_Family house',
    '5-minute social rescue circle',
    'Training rate_6 months',
    'Number of AEDs within 75m',
    'Use of electrical defibrillation in ED',
    'Use of mechanical CPR device in ED',
    'Establishment of advanced artificial airway in ED',
    'Use of medications in ED',
    'PCI for ED',
    'TTM for ED',
    'ECMO for ED',
    'Return of spontaneous circulation in ED',
]

datax = rawdata[x_features]

# Run model training and SHAP analysis
model_shap_train_vali_test_category_balance(datax, datay)
