In [1]:
from model_clf.model_param_grid import ModelsClassifier
import pandas as pd
from sklearn.datasets import load_breast_cancer

In [2]:
import os
os.chdir('D:/download/model_test')
# 加载示例数据
data = load_breast_cancer()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = pd.Series(data.target)

# 初始化模型分类器
clf = ModelsClassifier(
    model_name="XGBoost",  # 可换成你配置 JSON 中的模型名
    X=X,
    y=y,
    data_name="breast_cancer_demo",
    device="cpu"  # 或 "cuda" 取决于是否用GPU
)



In [3]:
# 训练模型并输出结果
results = clf.train_model()

# 绘制SHAP图
outdir="./shap_output"
clf.shap_plot(outdir=outdir)  # 保存到当前路径的 shap_output 目录

# # 保存模型
clf.save_model("rf_model.pkl")

Start RandomizedSearchCV: XGBoost
Fitting 5 folds for each of 30 candidates, totalling 150 fits
✅ 最优参数组合： {'subsample': 1.0, 'reg_lambda': 1.0, 'reg_alpha': 0.0, 'n_estimators': 300, 'max_depth': 3, 'learning_rate': 0.1, 'colsample_bytree': 0.6}
📈 最优 AUC: 0.9941176470588236
✅ XGBoost 最佳ROC AUC: 0.9950 (95% CI: 0.9860 - 1.0000)
   🎯 ACC: 0.9561 | Precision: 0.9467 | Recall: 0.9861 | F1: 0.9660
   📊 混淆矩阵 (Confusion Matrix):
           Pred 0    Pred 1    
   True 0  38        4         
   True 1  1         71        
X_test type is: <class 'pandas.core.frame.DataFrame'> True
🧠 已保存 SHAP summary plot 至: ./shap_output\XGBoost_breast_cancer_demo_shap_summary.pdf
✅ 模型已保存至: rf_model.pkl


In [4]:
print(clf.help())


ModelsClassifier 类帮助文档：

初始化方法：
    ModelsClassifier(model_name, X, y, data_name, test_size=0.2, random_seed=42, cv_n=5, device='cuda')
        - model_name: 模型名称(在 JSON 配置中定义)
        - X: 特征 DataFrame
        - y: 标签 Series
        - data_name: 用于命名输出的字符串
        - test_size: 测试集占比(默认 0.2)
        - random_seed: 随机种子(默认 42)
        - cv_n: 交叉验证折数(默认 5)
        - device: 设备选择('cuda' 或 'cpu')
        - X_test/y_test: 验证集/测试集,用于检验模型性能和绘制roc曲线

方法：
    train_model()
        - 执行模型训练、交叉验证、评估、保存结果至 self.results 并打印摘要

    shap_plot(plot_type='bar', shap_plot_file=None, force=False)
        - 生成 SHAP 可解释性图并保存 PDF/CSV/Pickle
        - 参数:plot_type 可选 'bar' 或 'dot'
        - shap_plot_file 可指定输出路径
        - force=True 时强制重新绘图

    search_best(save=False, model_file=None)
        - 执行超参数搜索，返回搜索器对象 self.searcher
        - 参数:save=True 时保存模型搜索器至文件

    compute_roc_auc_ci(y_true, y_scores, n_bootstraps=1000, alpha=0.95)
        - 计算 ROC AUC 的 bootstrap 置信区间
    
    roc_plot(roc_plot_file=None, 

In [5]:
clf.save_results("rf_result.pkl")

✅ results已保存至: rf_result.pkl


In [6]:
results = clf.load_results("rf_result.pkl")

📦 results已从 rf_result.pkl 加载。


In [7]:
X

Unnamed: 0,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,...,worst radius,worst texture,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension
0,17.99,10.38,122.80,1001.0,0.11840,0.27760,0.30010,0.14710,0.2419,0.07871,...,25.380,17.33,184.60,2019.0,0.16220,0.66560,0.7119,0.2654,0.4601,0.11890
1,20.57,17.77,132.90,1326.0,0.08474,0.07864,0.08690,0.07017,0.1812,0.05667,...,24.990,23.41,158.80,1956.0,0.12380,0.18660,0.2416,0.1860,0.2750,0.08902
2,19.69,21.25,130.00,1203.0,0.10960,0.15990,0.19740,0.12790,0.2069,0.05999,...,23.570,25.53,152.50,1709.0,0.14440,0.42450,0.4504,0.2430,0.3613,0.08758
3,11.42,20.38,77.58,386.1,0.14250,0.28390,0.24140,0.10520,0.2597,0.09744,...,14.910,26.50,98.87,567.7,0.20980,0.86630,0.6869,0.2575,0.6638,0.17300
4,20.29,14.34,135.10,1297.0,0.10030,0.13280,0.19800,0.10430,0.1809,0.05883,...,22.540,16.67,152.20,1575.0,0.13740,0.20500,0.4000,0.1625,0.2364,0.07678
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
564,21.56,22.39,142.00,1479.0,0.11100,0.11590,0.24390,0.13890,0.1726,0.05623,...,25.450,26.40,166.10,2027.0,0.14100,0.21130,0.4107,0.2216,0.2060,0.07115
565,20.13,28.25,131.20,1261.0,0.09780,0.10340,0.14400,0.09791,0.1752,0.05533,...,23.690,38.25,155.00,1731.0,0.11660,0.19220,0.3215,0.1628,0.2572,0.06637
566,16.60,28.08,108.30,858.1,0.08455,0.10230,0.09251,0.05302,0.1590,0.05648,...,18.980,34.12,126.70,1124.0,0.11390,0.30940,0.3403,0.1418,0.2218,0.07820
567,20.60,29.33,140.10,1265.0,0.11780,0.27700,0.35140,0.15200,0.2397,0.07016,...,25.740,39.42,184.60,1821.0,0.16500,0.86810,0.9387,0.2650,0.4087,0.12400


In [8]:
clf.roc_plot()

📉 ROC 曲线已保存至: ./XGBoost_breast_cancer_demo_roc_curve.pdf


In [9]:
results.keys()

dict_keys(['X_train', 'X_test', 'y_train', 'y_test', 'searcher', 'best_model', 'y_proba', 'y_pred', 'acc', 'precision', 'recall', 'f1', 'auc', 'fpr', 'tpr', 'auc_ci', 'confusion_matrix', 'X', 'y'])

In [10]:
clf.print_result()

✅ XGBoost 最佳ROC AUC: 0.9950 (95% CI: 0.9860 - 1.0000)
   🎯 ACC: 0.9561 | Precision: 0.9467 | Recall: 0.9861 | F1: 0.9660
   📊 混淆矩阵 (Confusion Matrix):
           Pred 0    Pred 1    
   True 0  38        4         
   True 1  1         71        
