## 仅供交叉验证 前馈神经网络（NNAR）-按趋势分类

In [1]:
import pickle
import numpy as np
import random
from utils import *
import tensorflow.keras as keras
from tensorflow.keras import regularizers
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.model_selection import StratifiedKFold
import tensorflow as tf

# Seed NumPy, TensorFlow and Python's `random` with the same value so that
# weight initialisation and other random draws are as reproducible as possible.
np.random.seed(42)
tf.random.set_seed(42)
random.seed(42)
# Number of leading time steps used as model input; the remaining time steps
# (data[:, n_input:, ...] in full_pipeline) are the prediction targets.
n_input = 11

### 读取数据

In [2]:
# gene_arr_path = r'../output/gene_editing/es_with_decay.array'
# transplant_arr_path = r'../output/transplant/es_with_decay.array'

# gene_arr = pickle.load(open(gene_arr_path, mode='rb'))
# transplant_arr = pickle.load(open(transplant_arr_path, mode='rb'))

# print('Shape of the gene_editing array:',gene_arr.shape)
# print('Shape of the transplant array:',transplant_arr.shape)

Shape of the gene_editing array: (2643, 17, 10)
Shape of the transplant array: (5141, 17, 10)


### 截断数据
2019年为无效数据，因此截去每条序列的最后一个时间步（即2019年），时间维度由17变为16。

In [3]:
# gene_arr = gene_arr[:, :-1, :]
# transplant_arr = transplant_arr[:, :-1, :]

# print('Shape of the gene_editing array:',gene_arr.shape)
# print('Shape of the transplant array:',transplant_arr.shape)

Shape of the gene_editing array: (2643, 16, 10)
Shape of the transplant array: (5141, 16, 10)


### 标准化数据并获取5折交叉验证所需的训练集和验证集

In [4]:
# scaler, data = scale_data(transplant_arr, 'standard')

# # 用预测第二年的类别变量作为分成Kfold的依据，不支持浮点数
# X, y, y_cat = data[:, :n_input, :], data[:, n_input:, -2],transplant_arr[:, n_input, -1]
# kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

### 按趋势划分

In [2]:
def split_data_by_trend(data, targets):
    """Split samples into an upward-trend group and a non-upward group.

    For every sample, a straight line is least-squares fitted to its feature
    ``-2`` series along the time axis. Samples with a strictly positive slope
    go to the "up" group; all others (slope <= 0) go to the "down" group.
    Within each group the original sample order is preserved.

    Args:
        data: array of shape (n_samples, n_steps, n_features).
        targets: array-like with one entry per sample; ``targets[i]`` belongs
            to ``data[i]``.

    Returns:
        (up_data, up_target, down_data, down_target). An empty group keeps the
        trailing dimensions of its input (e.g. ``(0, n_steps, n_features)``),
        so it can still be concatenated or passed to a model.

    Raises:
        ValueError: if ``data`` and ``targets`` have different lengths.
    """
    data = np.asarray(data)
    targets = np.asarray(targets)
    if len(data) != len(targets):
        raise ValueError(
            f'data and targets must have the same length, got {len(data)} and {len(targets)}')

    if len(data) == 0:
        is_up = np.zeros(0, dtype=bool)
    else:
        # np.polyfit fits every column of a 2-D `y` in one call; row 0 of the
        # coefficient matrix holds the slope of each sample's series.
        series = data[:, :, -2].T  # (n_steps, n_samples)
        steps = np.arange(data.shape[1])
        is_up = np.polyfit(steps, series, 1)[0] > 0

    # NOTE(review): the slope is taken on whatever scale `data` is in (the
    # notebook passes standardised inputs). Its sign equals the raw trend's sign
    # only if the scaling is a positive affine map per feature -- confirm
    # against utils.scale_data.
    return data[is_up], targets[is_up], data[~is_up], targets[~is_up]

### 构建模型

In [3]:
def root_mean_squared_error(y_true, y_pred):
    """Keras loss returning the root-mean-squared error of each sample.

    The squared difference is averaged along the last axis and then
    square-rooted, giving one value per sample. Keras' default loss reduction
    then averages these per-sample values over the batch, so the training loss
    is the mean of per-sample RMSEs rather than the RMSE of the whole batch.
    """
    backend = keras.backend
    squared_error = backend.square(y_pred - y_true)
    per_sample_mse = backend.mean(squared_error, axis=-1)
    return backend.sqrt(per_sample_mse)

# def build_direct_dnn_model():
#     model = keras.models.Sequential()
#     model.add(Flatten())
#     model.add(Dense(256, activation='relu'))
#     model.add(Dense(256, activation='relu'))
#     model.add(Dense(5))
    
#     optimizer=keras.optimizers.Adam(learning_rate=1e-4)
#     model.compile(loss=root_mean_squared_error, optimizer=optimizer)
#     return model
def build_direct_dnn_model(n_layers=2, n_units=256, *, n_outputs=5, learning_rate=1e-4):
    """Build and compile a feed-forward net that predicts all target steps at once.

    The input window is flattened, passed through ``n_layers`` ReLU ``Dense``
    layers of ``n_units`` units each, and mapped by a linear ``Dense`` layer to
    ``n_outputs`` values, one per predicted time step. The notebook predicts
    16 - n_input = 5 steps, hence the default. The model is compiled with Adam
    and ``root_mean_squared_error`` as the loss.

    Args:
        n_layers: number of hidden layers.
        n_units: width of each hidden layer.
        n_outputs: size of the linear output layer (number of target steps).
        learning_rate: Adam learning rate.

    Returns:
        A compiled ``keras.models.Sequential`` model. Its input shape is
        inferred on the first call to ``fit``/``predict``.
    """
    model = keras.models.Sequential()
    model.add(Flatten())
    for _ in range(n_layers):
        model.add(Dense(n_units, activation='relu'))
    model.add(Dense(n_outputs))

    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(loss=root_mean_squared_error, optimizer=optimizer)
    return model

### 进行训练和评估
使用EarlyStopping（monitor='val_loss'，patience=10，restore_best_weights=True）作为训练停止方式；ModelCheckpoint虽已导入但未使用。注意：早停所监控的验证集即当前折的测试集。

In [4]:
def cross_validation(X, y, y_cat, kfold, scaler, n_layers, n_units,
                     *, validation_split=None, verbose=1):
    """Run k-fold CV with one network per trend group and collect metrics.

    For each fold, the training and test halves are each split by
    ``split_data_by_trend`` into an upward and a non-upward group. A separate
    model, ``build_direct_dnn_model(n_layers, n_units)``, is trained per group.
    The groups' test predictions are then pooled and scored with ``eval_model``
    for every metric name.

    Args:
        X: model inputs, indexed by sample along axis 0.
        y: regression targets aligned with ``X``.
        y_cat: per-sample discrete labels, used only to stratify the folds.
        kfold: splitter with a ``split(X, y_cat)`` method (e.g. StratifiedKFold).
        scaler: forwarded unchanged to ``eval_model``.
        n_layers: number of hidden layers of each model.
        n_units: width of each hidden layer.
        validation_split: if None (the original behaviour), the fold's own test
            group is also the EarlyStopping validation data. This leaks test
            information into model selection, so the metrics are optimistic.
            Pass a fraction in (0, 1) to have Keras hold out that share of each
            training group instead (Keras takes the last samples, unshuffled).
        verbose: verbosity passed to ``model.fit``.

    Returns:
        (overall_metrics, annual_metrics): dicts mapping each metric name to a
        list with one entry per fold, as returned by ``eval_model``.

    Raises:
        ValueError: if a trend group has test samples but no training samples.
    """
    metric_names = ('mae', 'rmse', 'ndcg', 'mape', 'r2', 'pearson', 'acc')
    overall_metrics = {m: [] for m in metric_names}
    annual_metrics = {m: [] for m in metric_names}

    for train_idx, test_idx in kfold.split(X, y_cat):
        # Split both halves of the fold by trend direction (not by total volume).
        up_tr_X, up_tr_y, down_tr_X, down_tr_y = split_data_by_trend(X[train_idx], y[train_idx])
        up_te_X, up_te_y, down_te_X, down_te_y = split_data_by_trend(X[test_idx], y[test_idx])
        groups = [
            ((up_tr_X, up_tr_y), (up_te_X, up_te_y)),
            ((down_tr_X, down_tr_y), (down_te_X, down_te_y)),
        ]

        # Train one model per trend group. All models are trained before any
        # prediction is made, matching the original execution order.
        trained = []
        for (X_tr, y_tr), (X_te, y_te) in groups:
            if len(X_te) == 0:
                # No test samples with this trend in this fold: nothing to score.
                continue
            if len(X_tr) == 0:
                raise ValueError(
                    'A trend group has test samples but no training samples in '
                    'this fold; cannot train a model for it.')
            if validation_split is None:
                val_kwargs = {'validation_data': (X_te, y_te)}
            else:
                val_kwargs = {'validation_split': validation_split}
            model = build_direct_dnn_model(n_layers, n_units)
            model.fit(X_tr, y_tr, epochs=300, batch_size=16, verbose=verbose,
                      callbacks=[
                          EarlyStopping(monitor='val_loss', patience=10, verbose=0,
                                        mode='auto', restore_best_weights=True)
                      ],
                      **val_kwargs)
            trained.append((model, X_te, y_te))

        # Pool the per-group predictions so the metrics cover the whole test fold.
        y_true = np.concatenate([y_te for _, _, y_te in trained])
        y_pred = np.concatenate(
            [model.predict(X_te).reshape(y_te.shape) for model, X_te, y_te in trained])

        for m in metric_names:
            overall, annual = eval_model(m, y_true, y_pred, scaler)
            overall_metrics[m].append(overall)
            annual_metrics[m].append(annual)

    return overall_metrics, annual_metrics


In [5]:
def full_pipeline(*,
                  gene_arr_path=r'../output/gene_editing/es_with_decay.array',
                  transplant_arr_path=r'../output/transplant/es_with_decay.array',
                  output_path='mlp_metrics.dict'):
    """Cross-validate the trend-split network on both datasets and save the metrics.

    Both pickled arrays are loaded and their last time step (2019, treated as
    invalid) is dropped. Each dataset is standardised with ``scale_data``,
    cross-validated with 5-fold ``StratifiedKFold``, and its per-fold metrics
    are averaged. The resulting nested dict is pickled to ``output_path``.

    Args:
        gene_arr_path: pickle file holding the gene-editing array.
        transplant_arr_path: pickle file holding the transplant array.
        output_path: where the averaged metrics dict is pickled.

    Returns:
        ``{name: {'overall': {metric: mean}, 'annual': {metric: per-step mean}}}``
        for ``name`` in ``'gene'`` and ``'transplant'``.
    """
    def _load(path):
        # NOTE: pickle can execute arbitrary code -- only load files this
        # project produced itself.
        with open(path, mode='rb') as f:
            return pickle.load(f)

    # Drop the last time step (2019), which is treated as invalid.
    gene_arr = _load(gene_arr_path)[:, :-1, :]
    transplant_arr = _load(transplant_arr_path)[:, :-1, :]

    print('Shape of the gene_editing array:', gene_arr.shape)
    print('Shape of the transplant array:', transplant_arr.shape)

    datasets = {'gene': gene_arr, 'transplant': transplant_arr}
    # Per-dataset network size: (n_layers, n_units).
    hyperparams = {'gene': (3, 128), 'transplant': (5, 256)}
    metrics = {name: {'overall': {}, 'annual': {}} for name in datasets}

    for name, dataset in datasets.items():
        scaler, data = scale_data(dataset, 'standard')

        # Stratify the folds on the categorical feature (-1) at step n_input,
        # i.e. the first target step. It is read from the unscaled array
        # because StratifiedKFold needs discrete labels, not floats.
        X, y, y_cat = data[:, :n_input, :], data[:, n_input:, -2], dataset[:, n_input, -1]
        kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

        n_layers, n_units = hyperparams[name]
        overall_metrics, annual_metrics = cross_validation(
            X, y, y_cat, kfold, scaler, n_layers, n_units)

        # Average over folds: scalars for overall metrics, per-step vectors for annual ones.
        for metric, value in overall_metrics.items():
            metrics[name]['overall'][metric] = np.mean(value)
        for metric, value in annual_metrics.items():
            metrics[name]['annual'][metric] = np.mean(np.array(value), axis=0)

    # A context manager guarantees the pickle is flushed and the file closed.
    with open(output_path, 'wb') as f:
        pickle.dump(metrics, f)

    return metrics

In [6]:
# Run the cross-validation for both datasets. full_pipeline also pickles the
# averaged metrics to 'mlp_metrics.dict'.
metrics = full_pipeline()

Shape of the gene_editing array: (2643, 16, 10)
Shape of the transplant array: (5141, 16, 10)
Train on 988 samples, validate on 256 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Train on 1124 samples, validate on 275 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Train on 998 samples, validate on 246 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
E

Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Train on 995 samples, validate on 249 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Train on 1120 samples, validate on 279 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Train on 991 samples, validate on 253 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epo

Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Train on 1004 samples, validate on 240 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 38/300
Epoch 39/300
Epoch 40/300
Epoch 41/300
Epoch 42/300
Epoch 43/300
Epoch 44/300
Epoch 45/300
Epoch 46/300
Epoch 47/300
Epoch 48/300
Epoch 49/300
Epoch 50/300
Epoch 51/300
Epoch 52/300
Epoch 53/300
Epoch 54/300
Epoch 55/300
Epoch 56/300
Epoch 57/300
Epoch 58/300
Epoch 59/300
Epoch 60/300
Epoch 61/300
Epoch 62/300
Epoch 6

Epoch 64/300
Epoch 65/300
Epoch 66/300
Epoch 67/300
Epoch 68/300
Epoch 69/300
Epoch 70/300
Epoch 71/300
Epoch 72/300
Epoch 73/300
Epoch 74/300
Train on 1112 samples, validate on 287 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Train on 3531 samples, validate on 884 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Train on 580 samples, validate on 146 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
E

Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Train on 3542 samples, validate on 873 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Train on 570 samples, validate on 156 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Train on 3535 samples, validate on 880 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Train on 577 sam

Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Train on 3515 samples, validate on 900 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Train on 599 samples, validate on 127 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Train on 3537 samples, validate on 878 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Train on 578 samples, validate on 148 samples
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch

In [7]:
# Display the averaged per-dataset metrics returned by full_pipeline().
metrics

{'gene': {'overall': {'mae': 0.7291975904675908,
   'rmse': 1.2871031342506778,
   'ndcg': 0.4673787741071691,
   'mape': 4.422155277734545,
   'r2': 0.24244890230809224,
   'pearson': 0.524190051690691,
   'acc': 0.3317444264831185},
  'annual': {'mae': array([0.41338131, 0.57355205, 0.74987375, 0.88439481, 1.02478603]),
   'rmse': array([0.75304537, 0.95914567, 1.31467579, 1.48858523, 1.6651921 ]),
   'ndcg': array([0.4623395 , 0.43362874, 0.14620785, 0.19055866, 0.15106367]),
   'mape': array([3.42139588, 2.99739742, 3.24076012, 6.14907401, 6.30214895]),
   'r2': array([0.42885064, 0.29498014, 0.18880622, 0.05274659, 0.02198769]),
   'pearson': array([0.66843551, 0.56805361, 0.47478036, 0.32031706, 0.26847979]),
   'acc': array([0.58150882, 0.18466888, 0.23798611, 0.33935394, 0.31520438])}},
 'transplant': {'overall': {'mae': 0.7658888161691443,
   'rmse': 1.271241089721329,
   'ndcg': 0.4850424174440864,
   'mape': 3.8585934696807804,
   'r2': 0.42290244232094354,
   'pearson': 0.6

In [7]:
# Display the averaged per-dataset metrics returned by full_pipeline().
metrics

{'gene': {'overall': {'mae': 0.7297824076242991,
   'rmse': 1.288492394560198,
   'ndcg': 0.450425276583966,
   'mape': 4.379720797710389,
   'r2': 0.24096271379741033,
   'pearson': 0.5239841606169792,
   'acc': 0.33765443023983865},
  'annual': {'mae': array([0.41629491, 0.57294854, 0.75242516, 0.88069069, 1.02655274]),
   'rmse': array([0.75623629, 0.96455199, 1.31755807, 1.48652634, 1.66621578]),
   'ndcg': array([0.42948659, 0.36199947, 0.15691518, 0.21310819, 0.14353297]),
   'mape': array([3.27338697, 2.92294838, 3.34444352, 6.13713268, 6.22069244]),
   'r2': array([0.42536017, 0.28595323, 0.18519127, 0.05494907, 0.02053747]),
   'pearson': array([0.66756681, 0.55964269, 0.47305132, 0.33346943, 0.2686703 ]),
   'acc': array([0.57891309, 0.20543948, 0.25955221, 0.3371119 , 0.30725548])}},
 'transplant': {'overall': {'mae': 0.7741190799377897,
   'rmse': 1.2768263814311376,
   'ndcg': 0.521502225186101,
   'mape': 3.5752431968110856,
   'r2': 0.4178778432693182,
   'pearson': 0.65