## 仅供交叉验证 线性回归

In [1]:
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
from utils import *
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import StratifiedKFold

np.random.seed(42)
random.seed(42)
n_input = 11

读取数据

In [2]:
# gene_arr_path = r'../output/gene_editing/es_with_decay.array'
# transplant_arr_path = r'../output/transplant/es_with_decay.array'

# gene_arr = pickle.load(open(gene_arr_path, mode='rb'))
# transplant_arr = pickle.load(open(transplant_arr_path, mode='rb'))

# print('Shape of the gene_editing array:',gene_arr.shape)
# print('Shape of the transplant array:',transplant_arr.shape)

Shape of the gene_editing array: (2643, 17, 10)
Shape of the transplant array: (5141, 17, 10)


### 截断数据
2019年为无效数据

In [3]:
# gene_arr = gene_arr[:, :-1, :]
# transplant_arr = transplant_arr[:, :-1, :]

# print('Shape of the gene_editing array:',gene_arr.shape)
# print('Shape of the transplant array:',transplant_arr.shape)

Shape of the gene_editing array: (2643, 16, 10)
Shape of the transplant array: (5141, 16, 10)


## 规范数据并获取5折交叉检验所需的训练集和验证集

In [6]:
# scaler, data = scale_data(transplant_arr, 'standard')

# # 用预测第二年的类别变量作为分成Kfold的依据，不支持浮点数
# X, y, y_cat = data[:, :n_input, :], data[:, n_input:, -2],transplant_arr[:, n_input, -1]
# kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

### 构建模型，训练并评估

In [2]:
def cross_validation(X, y, y_cat, kfold, scaler):

    overall_metrics = {
        'mae':[],
        'rmse':[],
        'ndcg':[]
    }

    annual_metrics = {
        'mae':[],
        'rmse':[],
        'ndcg':[]
    }

    for train, test in kfold.split(X, y_cat):
        model = LinearRegression()
        model.fit(X[train].reshape(len(train), -1), y[train])

        y_test = y[test]
        y_pred = model.predict(X[test].reshape(len(test), -1)).reshape(y[test].shape)

        metrics = ['mae', 'rmse','ndcg']
        for m in metrics:
            overall, annual = eval_model(m, y_test, y_pred, scaler)
            overall_metrics[m].append(overall)
            annual_metrics[m].append(annual)
    
    return overall_metrics, annual_metrics

In [3]:
def full_pipeline():
    gene_arr_path = r'../output/gene_editing/es_with_decay.array'
    transplant_arr_path = r'../output/transplant/es_with_decay.array'

    gene_arr = pickle.load(open(gene_arr_path, mode='rb'))
    transplant_arr = pickle.load(open(transplant_arr_path, mode='rb'))
    
    gene_arr = gene_arr[:, :-1, :]
    transplant_arr = transplant_arr[:, :-1, :]

    print('Shape of the gene_editing array:',gene_arr.shape)
    print('Shape of the transplant array:',transplant_arr.shape)
    
    metrics = {
        'gene':{
            'overall':{},
            'annual':{}
        },
        'transplant':{
            'overall':{},
            'annual':{}
        }
    }
    
    for name, dataset in zip(['gene', 'transplant'], [gene_arr, transplant_arr]):
        scaler, data = scale_data(dataset, 'standard')

        # 用预测第二年的类别变量作为分成Kfold的依据，不支持浮点数
        X, y, y_cat = data[:, :n_input, :], data[:, n_input:, -2], dataset[:, n_input, -1]
        kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
        
        overall_metrics, annual_metrics = cross_validation(X, y, y_cat, kfold, scaler)
        
        for metric, value in overall_metrics.items():
            metrics[name]['overall'][metric] = np.mean(value)
        
        for metric, value in annual_metrics.items():
            metrics[name]['annual'][metric] = np.mean(np.array(value), axis=0)
    
    pickle.dump(metrics, open('lr_metrics.dict', 'wb'))
    
    return metrics

In [4]:
metrics = full_pipeline()

Shape of the gene_editing array: (2643, 16, 10)
Shape of the transplant array: (5141, 16, 10)


In [5]:
metrics

{'gene': {'overall': {'mae': 0.841332715085354,
   'rmse': 1.521748249160741,
   'ndcg': 0.27832640637405115},
  'annual': {'mae': array([0.50798695, 0.65653711, 0.86775894, 1.00319185, 1.17118872]),
   'rmse': array([1.20531149, 1.2539588 , 1.6121069 , 1.61055601, 1.7801977 ]),
   'ndcg': array([0.32866693, 0.21297316, 0.17334743, 0.16026472, 0.07336666])}},
 'transplant': {'overall': {'mae': 0.8308358387334762,
   'rmse': 1.3109310386248048,
   'ndcg': 0.38689913510732726},
  'annual': {'mae': array([0.81742317, 0.82753902, 0.79332126, 0.83119176, 0.88470399]),
   'rmse': array([1.36195737, 1.31246435, 1.23893937, 1.26386641, 1.36408821]),
   'ndcg': array([0.0335395 , 0.02458765, 0.0193729 , 0.06673059, 0.09900777])}}}

In [10]:
overall_metrics

{'mae': [0.8395327215970434,
  0.8266725403206769,
  0.8349507815194999,
  0.8247750116750805,
  0.828248138555081],
 'rmse': [1.3425170271892428,
  1.289811886300267,
  1.2832247565538735,
  1.2876725496977182,
  1.3514289733829223],
 'ndcg': [0.07177541672729326,
  0.6606751289426189,
  0.5578098604528476,
  0.5977893977536649,
  0.04644587166021171]}