## 仅供交叉验证 前馈神经网络（NNAR）

In [1]:
import pickle
import numpy as np
import random
from utils import *
import tensorflow.keras as keras
from tensorflow.keras import regularizers
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.model_selection import StratifiedKFold
import tensorflow as tf

np.random.seed(42)
tf.random.set_seed(42)
random.seed(42)
n_input = 11

读取数据

In [2]:
# gene_arr_path = r'../output/gene_editing/es_with_decay.array'
# transplant_arr_path = r'../output/transplant/es_with_decay.array'

# gene_arr = pickle.load(open(gene_arr_path, mode='rb'))
# transplant_arr = pickle.load(open(transplant_arr_path, mode='rb'))

# print('Shape of the gene_editing array:',gene_arr.shape)
# print('Shape of the transplant array:',transplant_arr.shape)

Shape of the gene_editing array: (2643, 17, 10)
Shape of the transplant array: (5141, 17, 10)


### 截断数据
2019年为无效数据

In [3]:
# gene_arr = gene_arr[:, :-1, :]
# transplant_arr = transplant_arr[:, :-1, :]

# print('Shape of the gene_editing array:',gene_arr.shape)
# print('Shape of the transplant array:',transplant_arr.shape)

Shape of the gene_editing array: (2643, 16, 10)
Shape of the transplant array: (5141, 16, 10)


### 规范数据并获取5折交叉检验所需的训练集和验证集

In [4]:
# scaler, data = scale_data(transplant_arr, 'standard')

# # 用预测第二年的类别变量作为分成Kfold的依据，不支持浮点数
# X, y, y_cat = data[:, :n_input, :], data[:, n_input:, -2],transplant_arr[:, n_input, -1]
# kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

### 构建模型

In [12]:
def root_mean_squared_error(y_true, y_pred):
        return keras.backend.sqrt(keras.backend.mean(keras.backend.square(y_pred - y_true), axis=-1)) 

def build_direct_dnn_model(n_layers=2, n_units=256):
    model = keras.models.Sequential()
    model.add(Flatten())
    for i in range(n_layers):
        model.add(Dense(n_units, activation='relu'))
    model.add(Dense(5))
    
    optimizer=keras.optimizers.Adam(learning_rate=1e-4)
    model.compile(loss=root_mean_squared_error, optimizer=optimizer)
    return model

### 进行训练和评估
使用EarlyStopping和Checkpoint做训练停止方式

In [25]:
def cross_validation(X, y, y_cat, kfold, scaler, n_layers, n_units):
    overall_metrics = {
        'mae':[],
        'rmse':[],
        'ndcg':[],
        'mape':[],
        'r2':[],
        'pearson':[],
        'acc':[]
    }

    annual_metrics = {
        'mae':[],
        'rmse':[],
        'ndcg':[],
        'mape':[],
        'r2':[],
        'pearson':[],
        'acc':[]
    }

    tests = []
    preds = []

    for train, test in kfold.split(X, y_cat):
        model = build_direct_dnn_model(n_layers, n_units)
        history = model.fit(X[train], y[train], epochs=300, batch_size=16, verbose=0, validation_data=(X[test], y[test]),
                           callbacks=[
                               EarlyStopping(monitor='val_loss', patience=10, verbose=0, mode='auto', restore_best_weights=True)
                           ])

        y_test = y[test]
        y_pred = model.predict(X[test]).reshape(y[test].shape)

        tests.append(y_test)
        preds.append(y_pred)

        metrics = ['mae', 'rmse','ndcg', 'mape', 'r2', 'pearson', 'acc']
        for m in metrics:
            overall, annual = eval_model(m, y_test, y_pred, scaler)
            overall_metrics[m].append(overall)
            annual_metrics[m].append(annual)
    
    return overall_metrics, annual_metrics, tests, preds

In [26]:
def full_pipeline(n_layers, n_units):
    gene_arr_path = r'../output/gene_editing/es_with_decay.array'
    transplant_arr_path = r'../output/transplant/es_with_decay.array'

    gene_arr = pickle.load(open(gene_arr_path, mode='rb'))
    transplant_arr = pickle.load(open(transplant_arr_path, mode='rb'))
    
    gene_arr = gene_arr[:, :-1, :]
    transplant_arr = transplant_arr[:, :-1, :]

    print('Shape of the gene_editing array:',gene_arr.shape)
    print('Shape of the transplant array:',transplant_arr.shape)
    
    metrics = {
        'gene':{
            'overall':{},
            'annual':{}
        },
        'transplant':{
            'overall':{},
            'annual':{}
        }
    }

    for name, dataset in zip(['gene', 'transplant'], [gene_arr, transplant_arr]):
        scaler, data = scale_data(dataset, 'standard')

        # 用预测第二年的类别变量作为分成Kfold的依据，不支持浮点数
        X, y, y_cat = data[:, :n_input, :], data[:, n_input:, -2], dataset[:, n_input, -1]
        kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
        
        overall_metrics, annual_metrics, tests, preds = cross_validation(X, y, y_cat, kfold, scaler, n_layers, n_units)
        pickle.dump(tests, open('mlp_tests_{}.list'.format(name), 'wb'))
        pickle.dump(preds, open('mlp_preds_{}.list'.format(name), 'wb'))
        
        for n in [1, 5, 10, 15, 20]:
            overall_total = []

            for test, pred in zip(tests, preds):
                overall, _ = eval_model('ndcg', test, pred, scaler, n)
                overall_total.append(overall)
            print(n, np.mean(overall_total))
        
        for metric, value in overall_metrics.items():
            metrics[name]['overall'][metric] = np.mean(value)
        
        for metric, value in annual_metrics.items():
            metrics[name]['annual'][metric] = np.mean(np.array(value), axis=0)
        
        print('=====')
    
    pickle.dump(metrics, open('mlp_metrics.dict', 'wb'))

    
    return metrics

In [30]:
para_tuning_metrics = {}
for n_layers in [1, 2, 3, 4, 5]:
    for n_units in [32, 64, 128, 256, 512]:
        print(n_layers, n_units)
        para_tuning_metrics[(n_layers, n_units)] = full_pipeline(n_layers, n_units)
        print(para_tuning_metrics[(n_layers, n_units)])
#         print('gene_mae', para_tuning_metrics[(n_layers, n_units)]['gene']['overall']['mae'])
#         print('gene_rmse', para_tuning_metrics[(n_layers, n_units)]['gene']['overall']['rmse'])
#         print('transplant_mae', para_tuning_metrics[(n_layers, n_units)]['transplant']['overall']['mae'])
#         print('transplant_rmse', para_tuning_metrics[(n_layers, n_units)]['transplant']['overall']['rmse'])
        print()

1 32
Shape of the gene_editing array: (2643, 16, 10)
Shape of the transplant array: (5141, 16, 10)
1 0.42223334685804026
5 0.35001070592903033
10 0.35057763060103697
15 0.40305041028442606
20 0.4030497034479765
=====
1 0.3187007071771604
5 0.37091928396651014
10 0.41571742977640874
15 0.42182835625178783
20 0.4214619927954514
=====
{'gene': {'overall': {'mae': 0.7327803319284378, 'rmse': 1.2959335207738767, 'ndcg': 0.4030497034479765, 'mape': 4.322187601151198, 'r2': 0.23198863655088736, 'pearson': 0.5161429292508992, 'acc': 0.33884438841513537}, 'annual': {'mae': array([0.41461366, 0.57982554, 0.75731367, 0.88078355, 1.03136524]), 'rmse': array([0.74714714, 0.98322185, 1.32686298, 1.48864691, 1.67905301]), 'ndcg': array([0.48381252, 0.43845342, 0.12296899, 0.2027257 , 0.13052723]), 'mape': array([3.14992644, 2.91052808, 3.19973485, 6.06415718, 6.28659145]), 'r2': array([0.43385913, 0.25211524, 0.17361413, 0.05257655, 0.00569197]), 'pearson': array([0.67106813, 0.55292899, 0.46424108, 

1 0.42223977128139667
5 0.34959309615686723
10 0.35022593698576193
15 0.35019772097922336
20 0.3501988605832549
=====
1 0.3480388590109006
5 0.35313684558896247
10 0.4428873829498273
15 0.4468595370659624
20 0.4482770773538453
=====
{'gene': {'overall': {'mae': 0.7320178630541421, 'rmse': 1.2910796874326202, 'ndcg': 0.3501988605832549, 'mape': 4.459782244706602, 'r2': 0.2375869978743081, 'pearson': 0.5197075492542245, 'acc': 0.3367569117836921}, 'annual': {'mae': array([0.41553986, 0.58125696, 0.75214719, 0.8821497 , 1.02899561]), 'rmse': array([0.76014362, 0.96989589, 1.32368032, 1.48407937, 1.66942508]), 'ndcg': array([0.42785991, 0.46781438, 0.12289096, 0.24941002, 0.15706987]), 'mape': array([3.67323624, 2.95112136, 3.20660414, 6.11926135, 6.34868813]), 'r2': array([0.41621242, 0.27486672, 0.17689717, 0.05738007, 0.0171259 ]), 'pearson': array([0.65672264, 0.56221382, 0.46602757, 0.33842595, 0.26325945]), 'acc': array([0.60694719, 0.18735119, 0.24100286, 0.33631572, 0.3121676 ])}},

1 0.42223977128139667
5 0.34971508909315047
10 0.35047788440511796
15 0.3505180862220584
20 0.4439101583663688
=====
1 0.3584857159721236
5 0.4813405006123087
10 0.4936973307793098
15 0.5023423563616369
20 0.5041599917548714
=====
{'gene': {'overall': {'mae': 0.7316555933025326, 'rmse': 1.292113163932878, 'ndcg': 0.4439101583663688, 'mape': 4.441983489999435, 'r2': 0.23630072807780725, 'pearson': 0.519517907007148, 'acc': 0.3345340333965918}, 'annual': {'mae': array([0.41575785, 0.57505114, 0.75518838, 0.8863728 , 1.0259078 ]), 'rmse': array([0.75276908, 0.96666706, 1.32874412, 1.48898919, 1.66901906]), 'ndcg': array([0.4307069 , 0.43430325, 0.14759589, 0.24761997, 0.16334112]), 'mape': array([3.33757999, 2.94182011, 3.29339361, 6.18845993, 6.44866381]), 'r2': array([0.42924669, 0.28272051, 0.17028796, 0.05197746, 0.01740001]), 'pearson': array([0.67069662, 0.55984527, 0.46272303, 0.32637217, 0.2626512 ]), 'acc': array([0.60803638, 0.18807228, 0.23795317, 0.32835398, 0.31025436])}}, 't

1 0.42223977128139667
5 0.3494014970503774
10 0.3501741777784753
15 0.3501538256259707
20 0.4479059859870664
=====
1 0.3322732037069516
5 0.39771784470344634
10 0.42838496453945785
15 0.4352213959180286
20 0.44159807944431
=====
{'gene': {'overall': {'mae': 0.7335684832957228, 'rmse': 1.2922944350633805, 'ndcg': 0.4479059859870664, 'mape': 4.482581807464034, 'r2': 0.23601745397593019, 'pearson': 0.5190649426118484, 'acc': 0.33014447882981324}, 'annual': {'mae': array([0.4182274 , 0.57988144, 0.7562754 , 0.88881726, 1.02464092]), 'rmse': array([0.76011438, 0.96945587, 1.32721729, 1.48671481, 1.66772947]), 'ndcg': array([0.46887542, 0.42993897, 0.14918801, 0.20966424, 0.1430828 ]), 'mape': array([3.33366699, 2.96261286, 3.26095922, 6.36504602, 6.49062395]), 'r2': array([0.41875209, 0.27724952, 0.17341078, 0.05460224, 0.01863202]), 'pearson': array([0.66026975, 0.55496676, 0.46336553, 0.32159393, 0.26355995]), 'acc': array([0.6061438 , 0.18314441, 0.2266052 , 0.32532006, 0.30950891])}}, '

1 0.42223977128139667
5 0.349887929943357
10 0.3504874385961415
15 0.3504594149077095
20 0.4441676509158631
=====
1 0.3771619952945807
5 0.3759515981509415
10 0.45751151609604734
15 0.45753522557009985
20 0.46118478829179804
=====
{'gene': {'overall': {'mae': 0.7322117046189256, 'rmse': 1.2895884799467727, 'ndcg': 0.4441676509158631, 'mape': 4.565080551859867, 'r2': 0.23915423012341078, 'pearson': 0.523814303351986, 'acc': 0.33340283508647045}, 'annual': {'mae': array([0.41658563, 0.57941794, 0.75621436, 0.88637285, 1.02246773]), 'rmse': array([0.75396287, 0.96788499, 1.32259892, 1.48485103, 1.66716417]), 'ndcg': array([0.42825082, 0.46148052, 0.12540957, 0.25852042, 0.15766231]), 'mape': array([3.33792636, 3.26193072, 3.35263134, 6.41775657, 6.45515778]), 'r2': array([0.42738219, 0.28000031, 0.17745055, 0.0575318 , 0.01944424]), 'pearson': array([0.67018035, 0.56731073, 0.46592442, 0.33185558, 0.26863327]), 'acc': array([0.59707434, 0.18314659, 0.23982435, 0.33328828, 0.3136806 ])}}, 

In [32]:
# pickle.dump(para_tuning_metrics, open('para_tuning_metrics.dict', 'wb'))

In [13]:
metrics = full_pipeline()

Shape of the gene_editing array: (2643, 16, 10)
Shape of the transplant array: (5141, 16, 10)
Train on 2112 samples, validate on 531 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Train on 2113 samples, validate on 530 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Train on 2115 samples, validate on 528 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100

Epoch 27/100
Train on 2116 samples, validate on 527 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Train on 2116 samples, validate on 527 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
1 0.42223977128139667
5 0.34985699764038586
10 0.4080483947050876
15 0.4620946659191977
20 0.47781230459386864
=====
Train on 4111 samples, validate on 1030 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/

Train on 4112 samples, validate on 1029 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Train on 4112 samples, validate on 1029 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Train on 4114 samples, validate on 1027 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Train on 4115 samples, validate on 1026 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/1

1 0.3582524384574738
5 0.3626105312943469
10 0.4453733401969999
15 0.4484424672821711
20 0.4540192432896492
=====


In [6]:
metrics

{'gene': {'overall': {'mae': 0.7299490896746337,
   'rmse': 1.2918968704150449,
   'ndcg': 0.44835543717483367,
   'mape': 4.43813367441968,
   'r2': 0.2366877334464476,
   'pearson': 0.5213191822517468,
   'acc': 0.34326412607594264},
  'annual': {'mae': array([0.41098335, 0.57326106, 0.75468637, 0.88329084, 1.02752384]),
   'rmse': array([0.74612466, 0.97067574, 1.33166402, 1.48635163, 1.66849718]),
   'ndcg': array([0.46505046, 0.36096419, 0.14908492, 0.19261638, 0.15743147]),
   'mape': array([3.48005609, 2.8328771 , 3.14869389, 6.23453968, 6.49450161]),
   'r2': array([0.43897391, 0.27529658, 0.16778358, 0.05561811, 0.01792775]),
   'pearson': array([0.67472678, 0.56063183, 0.46457533, 0.32847791, 0.26389998]),
   'acc': array([0.60726591, 0.22830786, 0.23867794, 0.33481486, 0.30725406])}},
 'transplant': {'overall': {'mae': 0.7718043551135649,
   'rmse': 1.2752554128409044,
   'ndcg': 0.48122490028670706,
   'mape': 3.659302529390196,
   'r2': 0.4193337345315111,
   'pearson': 0.

In [6]:
metrics

{'gene': {'overall': {'mae': 0.7299490896746337,
   'rmse': 1.2918968704150449,
   'ndcg': 0.44835543717483367},
  'annual': {'mae': array([0.41098335, 0.57326106, 0.75468637, 0.88329084, 1.02752384]),
   'rmse': array([0.74612466, 0.97067574, 1.33166402, 1.48635163, 1.66849718]),
   'ndcg': array([0.46505046, 0.36096419, 0.14908492, 0.19261638, 0.15743147])}},
 'transplant': {'overall': {'mae': 0.7718043551135649,
   'rmse': 1.2752554128409044,
   'ndcg': 0.48122490028670706},
  'annual': {'mae': array([0.75218725, 0.77410503, 0.74279975, 0.7759474 , 0.81398235]),
   'rmse': array([1.30078556, 1.29846621, 1.21739863, 1.22831554, 1.31962679]),
   'ndcg': array([0.09773368, 0.07247887, 0.02053076, 0.10604713, 0.11611737])}}}

In [7]:
metrics

{'gene': {'overall': {'mae': 0.7299490896746337,
   'rmse': 1.2918968704150449,
   'ndcg': 0.44835543717483367,
   'mape': 4.43813367441968,
   'r2': 0.2366877334464476,
   'pearson': 0.5213191822517468,
   'acc': 0.34326412607594264},
  'annual': {'mae': array([0.41098335, 0.57326106, 0.75468637, 0.88329084, 1.02752384]),
   'rmse': array([0.74612466, 0.97067574, 1.33166402, 1.48635163, 1.66849718]),
   'ndcg': array([0.46505046, 0.36096419, 0.14908492, 0.19261638, 0.15743147]),
   'mape': array([3.48005609, 2.8328771 , 3.14869389, 6.23453968, 6.49450161]),
   'r2': array([0.43897391, 0.27529658, 0.16778358, 0.05561811, 0.01792775]),
   'pearson': array([0.67472678, 0.56063183, 0.46457533, 0.32847791, 0.26389998]),
   'acc': array([0.60726591, 0.22830786, 0.23867794, 0.33481486, 0.30725406])}},
 'transplant': {'overall': {'mae': 0.7718043551135649,
   'rmse': 1.2752554128409044,
   'ndcg': 0.48122490028670706,
   'mape': 3.659302529390196,
   'r2': 0.4193337345315111,
   'pearson': 0.