In [1]:
'''
训练机器学习模型，预测 28天死亡率及院内死亡率
'''
import os, sys
from datetime import datetime
import argparse
import pandas as pd
import joblib
import torch

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold 

# Runtime-environment detection (Jupyter kernel vs. plain script)
IN_NOTEBOOK = None
def in_notebook():
    """Return True when this code is executing inside an IPython kernel (Jupyter).

    Looks up ``get_ipython`` in this module's global namespace (presumably present
    only when running in a notebook/IPython session — confirm). When it is absent,
    a stand-in returning ``None`` is used; ``None`` has no ``config`` attribute, so
    an empty dict is searched and the result is False.
    """
    get_shell = globals().get('get_ipython', lambda: None)
    shell_config = getattr(get_shell(), 'config', {})
    return 'IPKernelApp' in shell_config
    
# Resolve the project root (so the `src.*` imports below work) depending on
# whether we run inside Jupyter or as a command-line script.
IN_NOTEBOOK = in_notebook()
if IN_NOTEBOOK:
    from IPython.display import clear_output, display
    # In Jupyter the working directory is the notebook's folder; the root is one level up.
    notebook_dir = os.getcwd()
    src_path = os.path.abspath(os.path.join(notebook_dir, '..'))
else:
    # As a script, the root is the parent of the directory containing this file.
    src_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    parser = argparse.ArgumentParser(description='')
    # No options are defined yet; parsing still rejects unexpected CLI arguments.
    sys_args = parser.parse_args()

# Plain `if` rather than a conditional expression evaluated only for its side effect.
if src_path not in sys.path:
    sys.path.append(src_path)

from src.utils import *
from src.model_utils import *
from src.metrix import cal_ci, format_ci
from src.setup import *
from risk_setup import *

# Select the compute device and make sure the output folder for the risk models exists.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'current device: {DEVICE}')
risk_ml_models = f'{MODELS}/risk_models/ML_ref_model/'
os.makedirs(risk_ml_models, exist_ok=True)  # no error if the folder is already there

current device: cpu


In [2]:
# Load the imputed MIMIC-IV table, indexed by ID.
df = pd.read_csv(f'{DATA}/imputed/MIMIC_IV_clean_imputed.tsv.gz', sep='\t', index_col='ID')
# In the notebook, work on a 10% subsample for fast iteration. A fixed random_state
# keeps the subsample (and every result derived from it) reproducible across re-runs.
SUBSAMPLE_SEED = 42
df = df.sample(frac=0.1, random_state=SUBSAMPLE_SEED) if IN_NOTEBOOK else df
features, _, _, outcomes = get_risk_model_features()
# load_data also returns the 28-day mortality target (outcome_ix=0); it is discarded
# because the multi-task targets are taken directly from df below.
X, _ = load_data(df, outcome_ix=0)

# Multi-task targets: all outcome columns, aligned to the rows/order of X so X and y
# stay consistent even if load_data filters or reorders rows.
y = df.loc[X.index, outcomes].copy()
assert X.index.equals(y.index), 'X and y rows are misaligned'

# Standardize X.
# NOTE(review): the scaler is fit on every row loaded here; if later cells run
# cross-validation (StratifiedKFold is imported) on this X, held-out folds leak into
# the scaling statistics — consider fitting inside each training fold. TODO confirm.
std_processor = StandardScaler()
X = pd.DataFrame(std_processor.fit_transform(X), index=X.index, columns=X.columns)
# NOTE(review): in notebook mode this scaler is fit on the 10% subsample but is saved
# to the same path as the full-data run, overwriting that file.
joblib.dump(std_processor, f'{risk_ml_models}/MIMIC_StandardScaler.joblib')

print(f'training data: {X.shape}')

training data: (2001, 31)


# 设计 VAE

In [3]:
import numpy as np
import torch
from torch import nn
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.metrics import mean_squared_error, roc_auc_score
# Compute device (same expression as in the first cell; presumably repeated so this
# cell can be run on its own).
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Small constant, presumably a numerical-stability guard; not used within this cell.
EPS = 1e-10

class VAE(nn.Module):
    """Variational autoencoder over flat feature vectors.

    The encoder maps ``input_dim`` features to the mean and log-variance of a
    diagonal Gaussian over a ``latent_dim``-dimensional latent space; the decoder
    maps a latent vector back to ``input_dim`` reconstructed features.

    Args:
        input_dim: number of input features. NOTE(review): the default (30) differs
            from the 31 columns printed for the training data in this notebook —
            callers should pass ``X.shape[1]`` explicitly.
        hidden_dim: width of the two hidden layers in the encoder and the decoder.
        latent_dim: dimensionality of the latent space.
    """

    def __init__(self, input_dim=30, hidden_dim=64, latent_dim=10):
        super().__init__()

        # Encoder: shared trunk whose output feeds the mean / log-variance heads.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder: reconstructs the input features from the latent space.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)  # reconstructed features (no activation)
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for input ``x``."""
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar, sample=None):
        """Produce a latent vector via the reparameterization trick.

        z = mu + sigma * epsilon with sigma = exp(0.5 * logvar), epsilon ~ N(0, I).

        Args:
            mu: latent means from :meth:`encode`.
            logvar: latent log-variances from :meth:`encode`.
            sample: whether to draw a random sample. ``None`` (default) samples in
                training mode and returns ``mu`` in eval mode, so inference
                (e.g. predicted risks used for AUC) is deterministic.

        Returns:
            Latent tensor with the same shape as ``mu``.
        """
        if sample is None:
            sample = self.training
        if not sample:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``z`` back to reconstructed features."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, draw the latent vector, and decode.

        Returns:
            ``(recon, mu, logvar, z)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar, z


class MultiTaskPredictor(nn.Module):
    """Two-headed predictor on a latent vector: 28-day and in-hospital mortality.

    One shared ``Linear -> ReLU`` layer feeds two independent linear heads, and each
    head's output is passed through a sigmoid to give a probability.

    Args:
        latent_dim: size of the input latent vector ``z``.
        hidden_dim: width of the shared hidden layer.
    """

    def __init__(self, latent_dim=10, hidden_dim=32):
        super().__init__()
        # Attribute names and creation order are kept as-is: names are the
        # state_dict keys, and the order fixes the random weight initialisation.
        self.shared = nn.Sequential(nn.Linear(latent_dim, hidden_dim), nn.ReLU())
        self.fc_28d = nn.Linear(hidden_dim, 1)         # 28-day mortality head
        self.fc_inhospital = nn.Linear(hidden_dim, 1)  # in-hospital mortality head

    def forward(self, z):
        """Return ``(p_28d, p_inhospital)``; each has a trailing dimension of 1."""
        features = self.shared(z)
        heads = (self.fc_28d, self.fc_inhospital)
        return tuple(torch.sigmoid(head(features)) for head in heads)


# Example: end-to-end wrapper combining the VAE with the multi-task predictor
class VAEMultiTaskModel(nn.Module):
    """VAE whose latent vector ``z`` feeds a multi-task mortality predictor.

    Args:
        input_dim: number of input features given to the VAE.
        vae_hidden_dim: hidden width of the VAE encoder/decoder.
        latent_dim: latent size shared by the VAE and the predictor.
        predictor_hidden_dim: width of the predictor's shared hidden layer.
    """

    def __init__(self, input_dim=30, vae_hidden_dim=64, latent_dim=10, predictor_hidden_dim=32):
        super().__init__()
        # Built VAE first, then predictor — same names/order, so state_dict keys
        # and the random initialisation are unchanged.
        self.vae = VAE(input_dim, vae_hidden_dim, latent_dim)
        self.predictor = MultiTaskPredictor(latent_dim, predictor_hidden_dim)

    def forward(self, x):
        """Return ``(recon, mu, logvar, p_28d, p_inhospital)`` for the batch ``x``."""
        reconstruction, latent_mu, latent_logvar, latent_z = self.vae(x)
        risk_28d, risk_inhospital = self.predictor(latent_z)
        return reconstruction, latent_mu, latent_logvar, risk_28d, risk_inhospital


Unnamed: 0_level_0,sex,age,BMI,temperature,heart_rate,respir_rate,SBP,DBP,MAP,RBC,...,K+,Na+,APTT,Fg,PH,PaO2,PaO2/FiO2,PaCO2,HCO3-,Lac
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
12973912_32337068,-1.190238,0.781027,0.831221,0.116978,-0.514078,0.061784,-0.517026,-0.222869,-0.626303,0.629040,...,-0.005689,0.257928,-0.424413,0.268958,-0.955427,-0.670732,-0.088951,0.870313,0.548828,-0.787726
12556770_36784294,0.840168,-1.024745,0.545215,-0.677973,1.333647,-1.692210,0.346394,1.261534,0.961727,1.520771,...,1.750558,-1.268269,-0.032088,-0.510566,-0.243562,-0.529380,-0.950354,-0.351791,-0.913332,2.583998
15928447_39786912,0.840168,1.443143,-0.797077,0.625225,-1.540591,-0.894940,0.864447,2.631752,2.209465,-0.105327,...,0.264503,0.257928,-0.709740,-0.013538,-0.345257,-0.688401,-0.431626,-0.179720,-0.495572,-0.243900
12065060_36076971,0.840168,-0.242244,-0.104643,-0.026374,-0.514078,-1.054394,-0.344342,0.519332,0.110997,2.176456,...,-0.410977,-0.123622,-0.657178,-0.101957,0.061524,0.089037,0.774025,0.078385,0.757708,-0.130931
12171102_35728280,-1.190238,-0.483013,-0.591302,-0.378237,0.307133,0.061784,0.001026,0.690610,0.734866,1.402748,...,0.939983,0.639477,-0.905902,-0.168775,-1.057122,-0.573552,0.474839,-0.093685,-0.913332,-0.281692
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
13036533_37187494,0.840168,-0.422821,1.869740,0.051818,-0.822032,0.221237,1.296157,0.062593,0.281143,-0.616761,...,2.155846,0.639477,-0.112337,0.652724,-1.260512,-0.688401,0.676567,0.594596,-0.495572,-0.461430
15380896_31385470,-1.190238,0.480065,0.460157,0.195170,-0.052146,-0.416578,0.993960,-0.108684,0.508004,1.533885,...,-1.221552,-0.314396,-1.155563,0.925268,0.573996,-0.703177,-0.950354,0.507838,1.175468,0.354309
14658742_35548408,-1.190238,0.480065,-0.226742,-0.026374,0.769064,-1.054394,-0.560197,-0.622516,-0.626303,0.668381,...,0.129407,-1.268269,-1.119897,0.959938,0.875084,-0.140661,-0.950354,-0.695931,0.548828,-0.570196
16031267_39905629,0.840168,-1.987823,-1.501483,2.723374,0.204482,-0.576032,-0.603368,-0.451239,-0.512872,1.101133,...,-0.951360,-1.649818,-0.451162,0.302931,1.993729,0.707454,2.065607,-1.298178,0.548828,-0.733344
