# `auton-survival` 交叉验证生存回归

`auton-survival` 提供了一个简单易用的 API 来训练生存回归模型，并可通过交叉验证、以最小化 Brier 类指标（如积分 Brier 分数，IBS）为准则进行模型选择；本示例在 `fit` 中使用的是指定时间点上的 Brier 分数（`metric='brs'`）。在本笔记本中，我们演示如何使用 `auton-survival` 以交叉验证方式在 *SUPPORT* 数据集上训练生存模型。


In [None]:
import sys

# Relative sys.path entries are resolved against the current working
# directory, so this makes the PARENT of the directory the notebook is run
# from importable (presumably so a local auton_survival checkout one level
# up is picked up — TODO confirm the intended layout).
sys.path.append('../')

from auton_survival import datasets

# Load the SUPPORT dataset. The call returns a pair:
#   outcomes - the labels; the preview below shows 'event' and 'time' columns
#   features - the covariates that are preprocessed into model inputs later
outcomes, features = datasets.load_support()

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# Preview the first five rows of the outcome table ('event' and 'time').
outcomes.head(n=5)

Unnamed: 0,event,time
0,0,2029
1,1,4
2,1,47
3,1,133
4,0,2029


In [None]:
from auton_survival.preprocessing import Preprocessor

# Categorical covariates; they are one-hot encoded below (one_hot=True).
cat_feats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']

# Numeric covariates; missing values are imputed with the column mean below.
num_feats = [
    'age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp',
    'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph',
    'glucose', 'bun', 'urine', 'adlp', 'adls',
]

# NOTE: with cross-validation, preprocessing should be fit inside each
# training fold so that statistics from held-out rows do not leak into
# training. This example fits it once on the full dataset for simplicity;
# do not do this in production.
preprocessor = Preprocessor(
    cat_feat_strat='ignore',  # strategy for missing categorical values (see the Preprocessor docs for exact semantics)
    num_feat_strat='mean',    # impute missing numeric values with the column mean
)

# Fit and transform in one step.
#   one_hot=True  -> categorical columns are expanded into indicator columns
#                    (the preview below shows e.g. 'race_black', 'ca_yes')
#   fill_value=-1 -> forwarded to the imputer. NOTE(review): this may only
#                    take effect for some categorical strategies (e.g.
#                    'replace'); confirm against the Preprocessor docs.
# The preview below also shows numeric columns on a standardized-looking
# scale (mean-imputed entries appear as ~0, e.g. -6.389701e-16), which
# suggests numeric features are scaled as well; confirm the default scaling.
x = preprocessor.fit_transform(
    features,
    cat_feats=cat_feats,
    num_feats=num_feats,
    one_hot=True,
    fill_value=-1,
)


In [12]:
# Preview the first five rows of the preprocessed feature matrix.
x.head(n=5)


Unnamed: 0,age,num.co,meanbp,wblc,hrt,resp,temp,pafi,alb,bili,...,dzclass_Coma,income_$25-$50k,income_>$50k,income_under $11k,race_black,race_hispanic,race_other,race_white,ca_no,ca_yes
0,0.012772,-1.390013,0.449837,-0.693182,-0.892283,-0.138967,-0.881504,1.569019,-1.655686,-0.5238337,...,0,0,0,0,0,0,1,0,0,0
1,-0.148262,0.097711,-1.500702,0.51871,0.470382,1.114591,-2.005013,-1.495658,-6.389701e-16,9.880260000000001e-17,...,0,0,0,0,0,0,0,1,1,0
2,-0.635153,0.097711,-0.525432,-0.420176,-0.290175,0.487812,0.235766,-0.0831988,-6.389701e-16,-0.0789274,...,0,0,0,1,0,0,0,1,1,0
3,-1.299688,0.097711,-0.344827,-0.354697,-0.290175,0.905665,-1.680444,-3.003564e-16,-6.389701e-16,9.880260000000001e-17,...,0,0,0,1,0,0,0,1,0,0
4,1.105258,-0.646151,-0.922764,0.125837,0.470382,-0.347893,0.635237,-0.699767,-6.389701e-16,9.880260000000001e-17,...,0,0,0,0,0,0,0,1,1,0


In [None]:
import numpy as np

# Evaluation horizons: the 25th / 50th / 75th percentiles of the observed
# (uncensored, event == 1) survival times, returned as a plain Python list.
times = np.quantile(
    outcomes.loc[outcomes.event == 1, 'time'],
    [0.25, 0.5, 0.75],
).tolist()

In [13]:
# Display the evaluation horizons (quartiles of the observed event times).
times

[14.0, 58.0, 252.0]

In [19]:
# Cross-validated hyperparameter search for a survival regression model.
from auton_survival.experiments import SurvivalRegressionCV

# Hyperparameter grid: each key is a model parameter, each value the list of
# candidates to try. This grid has 1 x 1 x 2 x 1 = 2 configurations (the two
# "At hyper-param" blocks in the output below).
#   k             - a single candidate, k=3. For Deep Survival Machines this
#                   is the number of underlying parametric distributions in
#                   the mixture (per the DSM documentation).
#   distribution  - parametric family of those distributions; fixed to Weibull.
#   learning_rate - optimizer step size; 1e-4 and 1e-3 are compared.
#   layers        - hidden-layer widths; a single hidden layer of 100 units.
param_grid = {'k' : [3],
              'distribution' : ['Weibull'],
              'learning_rate' : [1e-4, 1e-3],
              'layers' : [[100]]}

# Model 'dsm' (Deep Survival Machines), evaluated with 5-fold cross-validation
# (folds 0-4 in the output below) over every configuration in param_grid.
# random_seed=0 fixes the seed so runs are reproducible.
experiment = SurvivalRegressionCV(model='dsm', num_folds=5, hyperparam_grid=param_grid, random_seed=0)

# Run the search on features x, outcomes and evaluation horizons `times`,
# choosing the configuration by metric='brs' (Brier score at the horizons in
# `times`; lower is better — how the three horizons are aggregated is internal
# to the library, TODO confirm). The trailing progress bars after the last
# fold in the output suggest the selected configuration is then refit once
# more, presumably on the full data; the returned model is used below.
model = experiment.fit(x, outcomes, times, metric='brs')

At hyper-param {'distribution': 'Weibull', 'k': 3, 'layers': [100], 'learning_rate': 0.0001}
At fold: 0


 14%|█▎        | 1374/10000 [00:03<00:19, 441.01it/s]
100%|██████████| 50/50 [00:10<00:00,  4.69it/s]


At fold: 1


 14%|█▍        | 1420/10000 [00:03<00:18, 456.01it/s]
100%|██████████| 50/50 [00:09<00:00,  5.11it/s]


At fold: 2


 18%|█▊        | 1816/10000 [00:04<00:20, 404.20it/s]
100%|██████████| 50/50 [00:10<00:00,  4.73it/s]


At fold: 3


 18%|█▊        | 1754/10000 [00:04<00:19, 427.91it/s]
100%|██████████| 50/50 [00:09<00:00,  5.40it/s]


At fold: 4


 18%|█▊        | 1835/10000 [00:03<00:16, 481.74it/s]
100%|██████████| 50/50 [00:11<00:00,  4.49it/s]


At hyper-param {'distribution': 'Weibull', 'k': 3, 'layers': [100], 'learning_rate': 0.001}
At fold: 0


 14%|█▎        | 1374/10000 [00:03<00:23, 360.60it/s]
100%|██████████| 50/50 [00:10<00:00,  4.80it/s]


At fold: 1


 14%|█▍        | 1420/10000 [00:03<00:19, 449.84it/s]
100%|██████████| 50/50 [00:08<00:00,  5.60it/s]


At fold: 2


 18%|█▊        | 1816/10000 [00:03<00:16, 488.01it/s]
100%|██████████| 50/50 [00:09<00:00,  5.30it/s]


At fold: 3


 18%|█▊        | 1754/10000 [00:03<00:18, 456.67it/s]
100%|██████████| 50/50 [00:11<00:00,  4.35it/s]


At fold: 4


 18%|█▊        | 1835/10000 [00:04<00:19, 420.91it/s]
100%|██████████| 50/50 [00:10<00:00,  4.84it/s]
 19%|█▉        | 1886/10000 [00:05<00:22, 363.20it/s]
100%|██████████| 50/50 [00:13<00:00,  3.77it/s]


In [20]:
# `experiment.folds` is the per-sample fold assignment of the CV split, not
# the number of folds: one fold id (0..4 here) per row. It is used as a
# boolean row mask against both `outcomes` and the predictions in the metric
# cells below, so its length matches the number of samples.
print(experiment.folds)

# Bare expression at the end of a cell: Jupyter displays the repr of the
# model object returned by experiment.fit.
model

[4 4 0 ... 0 1 0]


<auton_survival.estimators.SurvivalModel at 0x169f336eb60>

In [21]:
# Predicted risk at each horizon in `times`: one row per sample in x, one
# column per horizon. In the saved outputs each risk equals 1 minus the
# matching survival probability (e.g. 0.1603819 + 0.8396181 = 1), i.e. the
# probability of having experienced the event by that time — NOT a
# cumulative hazard.
out_risk = model.predict_risk(x, times)

# Predicted survival probability at each horizon in `times`, same layout as
# out_risk (rows = samples, columns = horizons).
out_survival = model.predict_survival(x, times)

In [22]:
# Display predicted risks (rows = samples, columns = the horizons in `times`).
out_risk

array([[0.1603819 , 0.33961486, 0.63719424],
       [0.29931735, 0.43278468, 0.60073989],
       [0.0994725 , 0.1826067 , 0.32752895],
       ...,
       [0.11510564, 0.22157763, 0.4087761 ],
       [0.18464419, 0.34439227, 0.59126821],
       [0.09926894, 0.1958434 , 0.37244221]])

In [23]:
# Display predicted survival probabilities (rows = samples, columns = `times`).
out_survival

array([[0.8396181 , 0.66038514, 0.36280576],
       [0.70068265, 0.56721532, 0.39926011],
       [0.9005275 , 0.8173933 , 0.67247105],
       ...,
       [0.88489436, 0.77842237, 0.5912239 ],
       [0.81535581, 0.65560773, 0.40873179],
       [0.90073106, 0.8041566 , 0.62755779]])

In [24]:
# Distinct fold ids present in the per-sample fold assignment.
{*experiment.folds}

{0, 1, 2, 3, 4}

In [25]:
# Per-fold Brier score ('brs') at each horizon in `times`; each printed line
# holds one value per horizon (lower is better).
from auton_survival.metrics import survival_regression_metric

# CAUTION: `out_survival` was produced by the single model returned by
# `experiment.fit` and applied to ALL rows of `x`, so each fold's rows were
# most likely part of that model's training data. These are in-sample
# (resubstitution) scores grouped by fold, not held-out cross-validation
# estimates, and will tend to be optimistic.
#
# NOTE(review): no separate training outcomes are passed, so any censoring
# distribution the metric needs is derived from the evaluated fold itself.
# The library appears to accept an `outcomes_train` argument for this;
# confirm against the auton-survival metrics documentation.
for fold in set(experiment.folds):
    # Row mask for the current fold, computed once and applied to both the
    # ground-truth outcomes and the predictions so the two stay aligned.
    in_fold = experiment.folds == fold
    print(survival_regression_metric('brs', outcomes[in_fold],
                                     out_survival[in_fold],
                                     times=times))

[0.13190191 0.19317956 0.20526357]
[0.12441962 0.19222479 0.20547373]
[0.12647612 0.19651744 0.20817839]
[0.12085242 0.19009112 0.20798935]
[0.1258191  0.19258311 0.20874244]




In [26]:
# Per-fold time-dependent concordance index ('ctd') at each horizon in
# `times`; each printed line holds one value per horizon (higher is better).
# Re-imported so this cell can run on its own.
from auton_survival.metrics import survival_regression_metric

# CAUTION: as with the Brier cell, `out_survival` comes from the single model
# returned by `experiment.fit` applied to all of `x`, so these per-fold values
# are in-sample scores grouped by fold, not held-out CV estimates.
for fold in set(experiment.folds):
    # Row mask for the current fold, shared by outcomes and predictions.
    in_fold = experiment.folds == fold
    print(survival_regression_metric('ctd', outcomes[in_fold],
                                     out_survival[in_fold],
                                     times=times))



[0.7563622711697736, 0.7250398040088132, 0.68945833187698]
[0.7877458728307729, 0.7248672946319378, 0.6919668920842937]
[0.7823673722061515, 0.7275873198582368, 0.687387408911719]




[0.7643429002061838, 0.7226409386190243, 0.6826610981775124]
[0.7551788099568923, 0.7127483542642455, 0.6799484091211137]


In [None]:
# Debug cell: echo the evaluation horizons once per distinct fold id.
# NOTE(review): `fold` is never used in the body, so this simply prints every
# value of `times` once for each distinct fold. With folds {0, 1, 2, 3, 4}
# that is 5 repetitions, but the saved output below shows only 3 — it is
# stale or truncated. This loop also rebinds the notebook globals `fold` and
# `time` (the latter shadows the stdlib module name if it was imported).
for fold in set(experiment.folds):
    # Walk the evaluation horizons.
    for time in times:
        # Print the horizon value.
        print(time)

14.0
58.0
252.0
14.0
58.0
252.0
14.0
58.0
252.0
