# Load libraries

In [3]:
import copy
import ast
import gzip
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from glob import glob
from tqdm import tqdm

from train_model import make_data, tuning_ElasticNet, train_Enet, train_GBLUP, GBLUP_coef

# Parameter tuning

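Tune ElasticNet hyperparameters separately on each of the 5 cross-validation training folds, for every trait, and save the tuned parameters to a CSV file per trait.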

In [None]:
genotype = "../genotype.csv"
phenotype = "../phenotype.csv"
CV = 5
optuna_seed = 618  # TODO(review): not passed to tuning_ElasticNet below — confirm whether it should be
trait_names = ["biomass", "grain_number"]

# Tune ElasticNet separately on each CV training fold of each trait.
for trait_name in trait_names:

    # Build test/training data for CV-fold cross-validation. seed=1024 is the same
    # value the training cell passes to make_data; presumably it fixes the fold
    # split, so tuned params for fold i line up with fold i at training time.
    test_data, train_data, position_data = make_data(trait_name, genotype, phenotype, CV, seed=1024)

    # One row per fold: [trait name, CV fold index, tuned parameters].
    tuning_params = []
    for i in range(CV):
        # train_data[i] is indexed as [0]/[1]; presumably (X_train, y_train) — confirm.
        CV_params = tuning_ElasticNet(train_data[i][0], train_data[i][1])
        tuning_params.append([trait_name, i, CV_params])

    # Save with explicit column names so the file is self-describing; the
    # training cell reads it back with index_col=0 and these same three columns.
    pd.DataFrame(tuning_params, columns=["trait", "CV", "params"]).to_csv(
        "../Enet_{}_params.csv".format(trait_name)
    )

# Train models & check accuracy

In [9]:
def train_multi(inputs):
    """Train one model on one CV fold and evaluate it on that fold's test data.

    Parameters
    ----------
    inputs : list
        ``[cv_index, test_data, train_data, params, model]``, as built by the
        training loop:

        - ``inputs[0]``: CV fold index, passed through unchanged to the result.
        - ``inputs[1]``: the fold's test data; ``inputs[1][1].values`` is stored
          as the observed test values (assumed to be a pandas object — confirm).
        - ``inputs[2]``: the fold's training data.
        - ``inputs[3]``: model parameters (the genotype file path for GBLUP,
          the tuned hyperparameters for Enet).
        - ``inputs[4]``: model name, ``"GBLUP"`` or ``"Enet"``.

    Returns
    -------
    list
        ``[cv_index, test_observed, test_pred, r2, coefs]``.

    Raises
    ------
    ValueError
        If the model name is neither ``"GBLUP"`` nor ``"Enet"``.
    """
    model = inputs[4]
    # Fail fast with a clear message rather than reaching `return` with
    # test_pred/r2/coefs never assigned.
    if model not in ("GBLUP", "Enet"):
        raise ValueError(f"Unknown model {model!r}; expected 'GBLUP' or 'Enet'")

    print(f"Model {model} start")
    cv_index, test_data, train_data, params = inputs[0], inputs[1], inputs[2], inputs[3]

    if model == "GBLUP":
        # train_GBLUP also returns estimated_h2, which is not recorded in the result.
        test_pred, estimated_h2, r2 = train_GBLUP(test_data, train_data, params)
        coefs = GBLUP_coef(train_data, params)
    else:
        test_pred, r2, coefs = train_Enet(test_data, train_data, params)

    return [cv_index, test_data[1].values, test_pred, r2, coefs]

In [13]:
genotype = "../genotype.csv"
phenotype = "../phenotype.csv"
CV = 5
optuna_seed = 1024  # TODO(review): not passed to any call in this cell — confirm whether it is needed
trait_names = ["biomass", "grain_number"]
models = ["GBLUP", "Enet"]


# Train every model on every CV fold; results are pickled once per trait.
for trait_name in trait_names:

    print(trait_name)

    summary = []

    # Make test & training data. seed=1024 is the same value the tuning cell
    # uses, so fold i here is presumably the fold its params were tuned on.
    test_data, train_data, position_data = make_data(trait_name, genotype, phenotype, CV, seed=1024, plot=False)

    for model in models:

        # Tuned hyperparameters are read from the file the tuning cell writes
        # (one row per trait and CV fold).
        if model != "GBLUP":
            params_table = pd.read_csv("../{}_{}_params.csv".format(model, trait_name), index_col=0)
            params_table.columns = ["trait", "CV", "params"]

        for i in range(CV):

            if model == "GBLUP":
                # GBLUP receives the genotype file path in the params slot.
                params = genotype
            else:
                # Look up this fold's own params; the lookup must live inside
                # the fold loop so each fold uses the params tuned for it.
                fold_params = params_table.loc[
                    (params_table["trait"] == trait_name) & (params_table["CV"] == i), "params"
                ]
                if fold_params.empty:
                    raise ValueError(f"No tuned {model} params for trait={trait_name!r}, CV={i}")
                # Params were written as a Python literal (e.g. a dict repr).
                params = ast.literal_eval(fold_params.values[0])

            # Train the model and record [trait, model, CV, observed, preds, r2, coefs].
            result = train_multi([i, test_data[i], train_data[i], params, model])
            summary.append([trait_name, model, *result])

    with gzip.open(f'../{trait_name}.pic.bin.gz', 'wb') as p:
        pickle.dump(summary, p)

biomass
The phenotype data: 219 The genotype data: 219
The merge data: (219, 1669)
Model GBLUP start
Model GBLUP start
Model GBLUP start
Model GBLUP start
Model GBLUP start
grain_number
The phenotype data: 219 The genotype data: 219
The merge data: (219, 1669)
Model GBLUP start
Model GBLUP start
Model GBLUP start
Model GBLUP start
Model GBLUP start


# Check results

In [15]:
trait_names = ["biomass", "grain_number"]
result_columns = ["trait", "model", "CV", "test_observed", "test_preds", "r2", "coefs"]

# Load each trait's pickled CV results, export them to CSV, display them,
# and collect all traits into one `summary` frame.
trait_frames = []
for trait_name in trait_names:
    print(trait_name)
    # NOTE: pickle.load can execute arbitrary code — only load files this
    # notebook wrote itself.
    with gzip.open(f'../{trait_name}.pic.bin.gz', 'rb') as p:
        tmp = pd.DataFrame(pickle.load(p), columns=result_columns)
    tmp.to_csv(f'../{trait_name}train_result.csv', index=False)
    display(tmp)
    trait_frames.append(tmp)

# Build the combined table once instead of growing a DataFrame in the loop.
summary = pd.concat(trait_frames, ignore_index=True) if trait_frames else pd.DataFrame(columns=result_columns)

biomass


Unnamed: 0,trait,model,CV,test_observed,test_preds,r2,coefs
0,biomass,GBLUP,0,"[121.5, 122.5, 123.5, 192.0, 162.0, 138.0, 141...","[136.56411954068, 124.32886117157, 125.5183551...",0.168474,"[0.00656690087518049, 0.0062964199030434, 0.00..."
1,biomass,GBLUP,1,"[147.0, 111.5, 52.0, 133.5, 142.5, 127.5, 161....","[127.171144041862, 126.588836856018, 127.19699...",0.373708,"[0.00417132160508687, 0.00391747322053258, 0.0..."
2,biomass,GBLUP,2,"[96.5, 127.5, 44.0, 119.5, 127.5, 155.5, 128.0...","[144.027460069513, 144.027460069513, 130.76327...",-0.04797,"[0.037061205210786, 0.0342689035945309, 0.0342..."
3,biomass,GBLUP,3,"[99.0, 120.5, 122.0, 136.5, 167.5, 110.0, 127....","[123.407818726475, 135.956173265438, 125.06612...",0.177288,"[0.0089628601648552, 0.00865554778366829, 0.00..."
4,biomass,GBLUP,4,"[129.5, 108.5, 62.0, 139.5, 133.5, 100.5, 118....","[138.221689892567, 124.040444226178, 124.32504...",0.03947,"[0.00752646410350697, 0.00863086739160251, 0.0..."


grain_number


Unnamed: 0,trait,model,CV,test_observed,test_preds,r2,coefs
0,grain_number,GBLUP,0,"[906.0, 755.0, 903.5, 1196.0, 989.0, 816.5, 79...","[852.59228512064, 756.607602091889, 768.909636...",0.101572,"[0.0430112285036432, 0.0504495778710091, 0.050..."
1,grain_number,GBLUP,1,"[1101.5, 671.5, 275.5, 897.5, 813.5, 780.0, 10...","[775.738335086172, 771.219871960075, 775.78759...",0.345047,"[0.0197648790477381, 0.0190827045635483, 0.019..."
2,grain_number,GBLUP,2,"[598.0, 781.0, 351.5, 876.5, 864.5, 868.5, 849...","[866.323702613034, 866.323702613034, 796.14183...",-0.012485,"[0.1053885466202, 0.117863733012091, 0.1178637..."
3,grain_number,GBLUP,3,"[822.0, 786.5, 791.5, 1031.5, 898.5, 731.0, 63...","[760.914483122914, 827.758897013637, 766.74295...",0.230053,"[0.0365382696722115, 0.0393522423673548, 0.039..."
4,grain_number,GBLUP,4,"[780.5, 603.5, 445.0, 863.0, 946.5, 879.0, 744...","[838.849661385455, 755.535915063864, 755.97624...",0.065544,"[0.03973543044088, 0.0484097751276084, 0.04840..."
