## 仅供交叉验证 前馈神经网络（NNAR）-按总量分类

In [1]:
import pickle
import numpy as np
import random
from utils import *
import tensorflow.keras as keras
from tensorflow.keras import regularizers
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.model_selection import StratifiedKFold
import tensorflow as tf

# Seed the NumPy, TensorFlow and stdlib random generators with the same fixed value
# so that repeated runs start from the same random state.
np.random.seed(42)
tf.random.set_seed(42)
random.seed(42)
# Number of leading time steps fed to the model as input; the steps from this
# index onward are used as prediction targets (see full_pipeline).
n_input = 11

### 读取数据

In [2]:
# gene_arr_path = r'../output/gene_editing/es_with_decay.array'
# transplant_arr_path = r'../output/transplant/es_with_decay.array'

# gene_arr = pickle.load(open(gene_arr_path, mode='rb'))
# transplant_arr = pickle.load(open(transplant_arr_path, mode='rb'))

# print('Shape of the gene_editing array:',gene_arr.shape)
# print('Shape of the transplant array:',transplant_arr.shape)

Shape of the gene_editing array: (2643, 17, 10)
Shape of the transplant array: (5141, 17, 10)


### 截断数据
2019年为无效数据，因此去掉每个样本的最后一个时间步（17 → 16）。

In [3]:
# gene_arr = gene_arr[:, :-1, :]
# transplant_arr = transplant_arr[:, :-1, :]

# print('Shape of the gene_editing array:',gene_arr.shape)
# print('Shape of the transplant array:',transplant_arr.shape)

Shape of the gene_editing array: (2643, 16, 10)
Shape of the transplant array: (5141, 16, 10)


### 规范数据并获取5折交叉检验所需的训练集和验证集

In [4]:
# scaler, data = scale_data(transplant_arr, 'standard')

# # 用预测第二年的类别变量作为分成Kfold的依据，不支持浮点数
# X, y, y_cat = data[:, :n_input, :], data[:, n_input:, -2],transplant_arr[:, n_input, -1]
# kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

### 按总量切分数据

In [3]:
def split_data_by_es(data, targets, n_steps=11):
    """Split samples into three groups ordered by their summed input value.

    For every sample, the second-to-last feature channel (``data[..., -2]``;
    presumably the "es" series this function is named after -- confirm) is
    summed over the first ``n_steps`` time steps. Samples are then sorted by
    that total in ascending order and cut into three consecutive groups.

    Args:
        data: Array indexed as ``[sample, time_step, feature]``.
        targets: Array whose first axis is aligned with ``data``.
        n_steps: Number of leading time steps included in the total. Defaults
            to 11, the value that was previously hard-coded here (it matches
            the notebook's ``n_input``).

    Returns:
        ``(data1, target1, data2, target2, data3, target3)`` for the lowest,
        middle and highest totals. The first two groups each hold
        ``len(data) // 3`` samples; the last group also receives the remainder
        when the sample count is not divisible by three.
    """
    totals = np.sum(data[:, :n_steps, -2], axis=1)
    order = np.argsort(totals)
    group_size = len(totals) // 3

    # Compute each group's sample indices once and reuse them for data and targets.
    low = order[:group_size]
    mid = order[group_size:2 * group_size]
    high = order[2 * group_size:]

    return data[low], targets[low], data[mid], targets[mid], data[high], targets[high]

### 构建模型

In [4]:
def root_mean_squared_error(y_true, y_pred):
    """Keras loss: root of the mean squared error taken over the last axis.

    The mean is computed along ``axis=-1`` only, so the result keeps every
    leading axis of the inputs.
    """
    backend = keras.backend
    squared_error = backend.square(y_pred - y_true)
    return backend.sqrt(backend.mean(squared_error, axis=-1))

def build_direct_dnn_model():
    """Create and compile the feed-forward network used for direct forecasting.

    Architecture: Flatten -> Dense(256, relu) -> Dense(256, relu) -> Dense(5).
    The model is compiled with the ``root_mean_squared_error`` loss and an Adam
    optimizer at a learning rate of 1e-4.

    Returns:
        The compiled ``Sequential`` model (weights are created on first use,
        since no input shape is declared here).
    """
    stack = [
        Flatten(),
        Dense(256, activation='relu'),
        Dense(256, activation='relu'),
        Dense(5),
    ]
    model = keras.models.Sequential(stack)
    model.compile(
        loss=root_mean_squared_error,
        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    )
    return model

### 进行训练和评估
使用EarlyStopping（restore_best_weights=True，训练结束后恢复验证损失最优的权重）作为训练停止方式；代码中并未使用ModelCheckpoint。

In [5]:
def cross_validation(X, y, y_cat, kfold, scaler):
    """Run K-fold cross-validation with one model per total-value group.

    For each fold, the training part and the test part are each split into
    three groups by ``split_data_by_es``. A separate model from
    ``build_direct_dnn_model`` is trained on every training group, then used to
    predict the matching test group. The three groups' predictions are
    concatenated and scored with ``eval_model``.

    Args:
        X: Model inputs, indexed by sample on the first axis.
        y: Regression targets aligned with ``X``.
        y_cat: Discrete labels passed to ``kfold.split`` for stratification.
        kfold: Splitter exposing ``split(X, y_cat)`` that yields
            ``(train_indices, test_indices)`` pairs.
        scaler: Passed through unchanged to ``eval_model``.

    Returns:
        ``(overall_metrics, annual_metrics)``: two dicts keyed by metric name,
        each value being a list with one entry per fold.

    NOTE(review): ``eval_model`` is not defined in this notebook (presumably it
    comes from ``utils``); it is expected to return an ``(overall, annual)``
    pair for a metric name.
    """
    # Single source of truth for the metric names; previously the same list was
    # spelled out three times (two dict literals plus the evaluation loop).
    metric_names = ('mae', 'rmse', 'ndcg', 'mape', 'r2', 'pearson', 'acc')
    overall_metrics = {name: [] for name in metric_names}
    annual_metrics = {name: [] for name in metric_names}

    for train_idx, test_idx in kfold.split(X, y_cat):
        # Split each part of the fold into three groups by total value.
        # split_data_by_es returns (data1, target1, data2, target2, data3, target3),
        # so even positions hold inputs and odd positions hold targets.
        # NOTE(review): the test part is grouped by its OWN ranking, not by
        # boundaries learned from the training part -- confirm this is intended.
        train_groups = split_data_by_es(X[train_idx], y[train_idx])
        test_groups = split_data_by_es(X[test_idx], y[test_idx])
        train_xs, train_ys = train_groups[0::2], train_groups[1::2]
        test_xs, test_ys = test_groups[0::2], test_groups[1::2]

        # Train one model per group.
        # NOTE(review): the test group doubles as the early-stopping validation
        # set (with restore_best_weights=True), so the reported scores are
        # selected on the very data they are measured on. Consider carving a
        # validation split out of the training group instead.
        models = []
        for group_x, group_y, val_x, val_y in zip(train_xs, train_ys, test_xs, test_ys):
            model = build_direct_dnn_model()
            model.fit(group_x, group_y, epochs=100, batch_size=16, verbose=1,
                      validation_data=(val_x, val_y),
                      callbacks=[
                          EarlyStopping(monitor='val_loss', patience=10, verbose=0,
                                        mode='auto', restore_best_weights=True)
                      ])
            models.append(model)

        # Predict each test group with its own model and stitch the groups
        # back together in the same (low, mid, high) order as the targets.
        fold_true = np.concatenate(test_ys)
        fold_pred = np.concatenate([
            model.predict(group_x).reshape(group_y.shape)
            for model, group_x, group_y in zip(models, test_xs, test_ys)
        ])

        for name in metric_names:
            overall, annual = eval_model(name, fold_true, fold_pred, scaler)
            overall_metrics[name].append(overall)
            annual_metrics[name].append(annual)

    return overall_metrics, annual_metrics

In [6]:
def full_pipeline():
    """Load both datasets, cross-validate on each, and persist the averaged metrics.

    Steps:
      1. Unpickle the gene-editing and transplant arrays and drop the last
         time step of each (the notebook marks 2019 as invalid data).
      2. Standard-scale each dataset with ``scale_data`` and run 5-fold
         stratified cross-validation through ``cross_validation``.
      3. Average every metric over the folds and write the result to
         ``mlp_metrics.dict`` in the working directory.

    Returns:
        Nested dict ``{'gene' | 'transplant': {'overall': {...}, 'annual': {...}}}``.
        ``overall`` holds the fold-mean of each metric; ``annual`` holds the
        fold-mean taken along axis 0 of the per-fold values.

    NOTE(review): ``scale_data`` is not defined in this notebook (presumably it
    comes from ``utils``); it is expected to return ``(scaler, scaled_array)``.
    """
    array_paths = {
        'gene': r'../output/gene_editing/es_with_decay.array',
        'transplant': r'../output/transplant/es_with_decay.array',
    }

    datasets = {}
    for name, path in array_paths.items():
        # SECURITY: pickle.load can execute arbitrary code embedded in the file;
        # only load arrays produced by this project.
        # The context manager closes the handle (previously it was never closed).
        with open(path, mode='rb') as fh:
            raw = pickle.load(fh)
        # Drop the final time step of every sample.
        datasets[name] = raw[:, :-1, :]

    print('Shape of the gene_editing array:', datasets['gene'].shape)
    print('Shape of the transplant array:', datasets['transplant'].shape)

    metrics = {name: {'overall': {}, 'annual': {}} for name in datasets}

    for name, dataset in datasets.items():
        scaler, data = scale_data(dataset, 'standard')

        # Inputs are the first n_input steps; targets are channel -2 of the
        # remaining steps. StratifiedKFold needs discrete labels (floats are
        # not supported), so folds are stratified on the last channel of the
        # unscaled array at the first target step.
        X, y, y_cat = data[:, :n_input, :], data[:, n_input:, -2], dataset[:, n_input, -1]
        kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

        overall_metrics, annual_metrics = cross_validation(X, y, y_cat, kfold, scaler)

        for metric, value in overall_metrics.items():
            metrics[name]['overall'][metric] = np.mean(value)

        for metric, value in annual_metrics.items():
            metrics[name]['annual'][metric] = np.mean(np.array(value), axis=0)

    # Close (and thereby flush) the output file deterministically.
    with open('mlp_metrics.dict', 'wb') as fh:
        pickle.dump(metrics, fh)

    return metrics

In [7]:
# Run the full experiment on both datasets. This trains one model per fold and
# per group (verbose Keras logs follow) and writes mlp_metrics.dict.
metrics = full_pipeline()

Shape of the gene_editing array: (2643, 16, 10)
Shape of the transplant array: (5141, 16, 10)
Train on 704 samples, validate on 177 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Train on 704 samples, validate on 177 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Train on 704 samples, validate on 177 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Ep

Epoch 51/100
Epoch 52/100
Train on 704 samples, validate on 176 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Train on 704 samples, validate on 176 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Train on 705 samples, validate on 178 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Train

Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Train on 705 samples, validate on 176 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Train on 705 samples, validate on 176 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Train on 705 samples, validate on 175 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100


Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Train on 705 samples, validate on 175 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Train on 706 samples, validate on 177 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Train on 705 samples, validate on 175 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Train on 705 samples, validate on 175 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoc

Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Train on 1370 samples, validate on 343 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch

Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Train on 1370 samples, validate on 343 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Train on 1371 samples, validate on 344 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Train on 1370 samples, validate on 343 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
E

Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Train on 1370 samples, validate on 343 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Train on 1372 samples, validate on 343 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Train on 1370 samples, validate on 343 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Train on 1370 samples, validate on 343 sa

Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Train on 1372 samples, validate on 343 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Train on 1371 samples, validate on 342 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Train on 1371 samples, validate on 342 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Train on 1372 samples, validate on 343 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100


Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Train on 1371 samples, validate on 342 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Train on 1371 samples, validate on 342 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Train on 1373 samples, validate on 342 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100


In [6]:
# NOTE(review): the output saved under this cell lists only mae/rmse/ndcg, while
# cross_validation now records seven metrics, and its execution count is out of
# order. It appears to come from an earlier run -- re-run or remove this cell.
metrics

{'gene': {'overall': {'mae': 0.7299490896746337,
   'rmse': 1.2918968704150449,
   'ndcg': 0.44835543717483367},
  'annual': {'mae': array([0.41098335, 0.57326106, 0.75468637, 0.88329084, 1.02752384]),
   'rmse': array([0.74612466, 0.97067574, 1.33166402, 1.48635163, 1.66849718]),
   'ndcg': array([0.46505046, 0.36096419, 0.14908492, 0.19261638, 0.15743147])}},
 'transplant': {'overall': {'mae': 0.7718043551135649,
   'rmse': 1.2752554128409044,
   'ndcg': 0.48122490028670706},
  'annual': {'mae': array([0.75218725, 0.77410503, 0.74279975, 0.7759474 , 0.81398235]),
   'rmse': array([1.30078556, 1.29846621, 1.21739863, 1.22831554, 1.31962679]),
   'ndcg': array([0.09773368, 0.07247887, 0.02053076, 0.10604713, 0.11611737])}}}

In [8]:
# Display the fold-averaged metrics returned by full_pipeline().
metrics

{'gene': {'overall': {'mae': 0.7323172166659129,
   'rmse': 1.296150959780827,
   'ndcg': 0.39375426995433405,
   'mape': 4.407860713666826,
   'r2': 0.23216118467795743,
   'pearson': 0.5154964653492857,
   'acc': 0.34331252699477943},
  'annual': {'mae': array([0.41667306, 0.57881098, 0.75487069, 0.88225787, 1.02897349]),
   'rmse': array([0.75489929, 0.98624051, 1.31984687, 1.49443502, 1.67494569]),
   'ndcg': array([0.51731997, 0.21897933, 0.15831747, 0.16538815, 0.11366666]),
   'mape': array([3.19242965, 2.7242381 , 3.31875722, 6.27535841, 6.52852019]),
   'r2': array([0.42758378, 0.25620273, 0.18235494, 0.04360206, 0.01077944]),
   'pearson': array([0.66936208, 0.53171588, 0.47137955, 0.31838994, 0.25930228]),
   'acc': array([0.60385537, 0.21832896, 0.24209849, 0.33745923, 0.31482059])}},
 'transplant': {'overall': {'mae': 0.7708396953327952,
   'rmse': 1.2728291969176666,
   'ndcg': 0.49547971873051616,
   'mape': 3.814376280019359,
   'r2': 0.42156821867225547,
   'pearson': 