In [2]:
# --------------------------------------------------------------------------

# ----------------- simpls models after sample_selection functions

# -------------------------------------------------------------------------


# ¡¡¡ --- !!! # ---> modules and data cases

# --- system modules

import sys
import datetime
import os


# Root folder of the project. The methodology modules, the data files and the
# experiment outputs used further down are all addressed relative to it.
# NOTE(review): absolute path -- adjust to the local checkout before running.
base_dir = "/sample_selection_simpls"

# --- data handling modules

import numpy as np
import pandas as pd
import scipy.io as sp_io
import scipy as sp

# --- visualization modules

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib import rcParams
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# --- my modules

# Put the project's methodology folders at the front of sys.path so that the
# local modules imported right below can be resolved (presumably they live in
# these folders -- confirm against the repository layout).
# Each insert goes to position 0, so the folder inserted last is searched first.
methods_dir = base_dir + '/methodology'  
sys.path.insert(0, methods_dir + '/model_building')
sys.path.insert(0, methods_dir + '/read_data')
sys.path.insert(0, methods_dir + '/sample_selection')
from class_chemometrics_data import chemometrics_data
from class_sample_selection import sample_selection
import simpls_module




# ¡¡¡ --- !!! # ---> base working directory and available data cases

# ************************************ init --- user 
# Registry of data cases: case key -> [case folder, prepared data file stem].
cases_dict = {
    "d0001": ["d0001_corn", "d0001_data_prepared_01"],
}
# ************************************ end --- user 

print("--------- imports loaded ----------")

# Name of the experiment folder used when reading and writing results below.
experimentID = "exp001"

# ¡¡¡ --- !!! # ---> data

# ************************************ init --- user 
caseID_key = "d0001"
# ************************************ end --- user 

# Resolve the folder and the file stem registered for the selected case.
case_dir, dname = cases_dict[caseID_key][0], cases_dict[caseID_key][1]
data_dir = '/data/' + case_dir + '/data_prepared/'

# ************************************ init --- user
# Load the prepared .mat file; the flags are forwarded unchanged to the
# project's chemometrics_data loader.
data_class = chemometrics_data(
    base_dir + data_dir + dname + '.mat',
    data_identifier=data_dir + dname,
    include_val=False,
    include_test=True,
    include_unlabeled=False,
    y_all_range=False,
    y_range=np.array([0]),
    obs_all_cal=True,
    shuffle=False,
)
# ************************************ end --- user

# Quick sanity output: two size attributes of the loaded data (presumably the
# number of calibration samples and of variables -- confirm) and the shape of
# the test predictor matrix.
print(data_class.ncal, data_class.K)
print(data_class.get_test()["xtest"].shape)

print("--------- data loaded for " + data_class.data_identifier + "----------")




--------- imports loaded ----------
56 700
(24, 700)
--------- data loaded for /data/d0001_corn/data_prepared/d0001_data_prepared_01!*moisture*!----------


In [26]:
# --- initialize numba functions
# Calls every simpls_module routine used later once (fit, predict, rmse, r2 and
# cross-validation) on the loaded calibration/test data. Presumably this
# triggers numba compilation up front so that the timed loop further down does
# not include compile time -- confirm in simpls_module.
# The values assigned here are not referenced again in the visible part of the
# notebook.

output_pls = simpls_module.simpls_fit(xx=data_class.get_cal()["xcal"], yy=data_class.get_cal()["ycal"], nlv=14)
ytest_pred = simpls_module.simpls_predict(data_class.get_test()["xtest"],  output_pls[0],output_pls[1],output_pls[2])
rmsep = simpls_module.rmse(data_class.get_test()["ytest"], ytest_pred, np.ones(ytest_pred.shape[0]))
r2 = simpls_module.r2(data_class.get_test()["ytest"], ytest_pred, np.ones(ytest_pred.shape[0]))
cv_output = simpls_module.simpls_univariate_cv(xx=data_class.get_cal()["xcal"], yy=data_class.get_cal()["ycal"], total_nlv=40, number_splits=10, number_repetitions=4)


print("done")

done


In [27]:
# --- functions for model performance

def model_performance_cv_val_test(X0,Y0,selected_rows,Xtest, Ytest,total_lv=25, cv_reps = 2, cv_splits = 10):
    """Fit a SIMPLS model on the selected rows and score it three ways.

    The rows of ``X0``/``Y0`` flagged with 1 in ``selected_rows`` form the
    calibration set, the rows flagged with 0 form the validation set. The
    model fitted on the calibration set is evaluated on the validation set,
    on the external test set and by cross-validation within the calibration
    set.

    Parameters
    ----------
    X0, Y0 : 2-D arrays with one row per candidate sample.
    selected_rows : array-like of 0/1 flags, one per row of ``X0``.
        Lists and other array-likes are accepted; entries that are neither
        0 nor 1 end up in neither subset.
    Xtest, Ytest : 2-D arrays holding the external test set.
    total_lv : value passed as ``nlv`` to ``simpls_fit`` and as ``total_nlv``
        to ``simpls_univariate_cv``.
    cv_reps : value passed as ``number_repetitions`` to the cross-validation.
    cv_splits : value passed as ``number_splits`` to the cross-validation
        (default 10, the value that used to be hard-coded).

    Returns
    -------
    dict with keys ``rmsecv``, ``rmseval``, ``rmsep``, ``r2cv``, ``r2val``
    and ``r2p``. The ``*val`` and ``*p`` entries are the first column of the
    corresponding ``simpls_module.rmse`` / ``simpls_module.r2`` result; the
    ``*cv`` entries are elements 0 and 1 of the cross-validation output.
    """

    # Coerce once so list / Series masks work exactly like ndarray masks.
    selection_flags = np.asarray(selected_rows)
    in_calibration = selection_flags == 1
    in_validation = selection_flags == 0

    # --- train, cv, val, test

    Xcal = X0[in_calibration,:]
    Ycal = Y0[in_calibration,:]
    Xval = X0[in_validation,:]
    Yval = Y0[in_validation,:]

    trained_pls = simpls_module.simpls_fit(xx=Xcal, yy=Ycal, nlv=total_lv)
    yval_pred = simpls_module.simpls_predict(Xval,  trained_pls[0],trained_pls[1],trained_pls[2])
    ytest_pred = simpls_module.simpls_predict(Xtest,  trained_pls[0],trained_pls[1],trained_pls[2])

    # The third argument is a vector of ones with one entry per predicted row
    # (presumably uniform sample weights -- confirm in simpls_module).
    rmseval = simpls_module.rmse(Yval, yval_pred, np.ones(yval_pred.shape[0]))[:,0]
    rmsep = simpls_module.rmse(Ytest, ytest_pred, np.ones(ytest_pred.shape[0]))[:,0]

    r2val = simpls_module.r2(Yval, yval_pred, np.ones(yval_pred.shape[0]))[:,0]
    r2p = simpls_module.r2(Ytest, ytest_pred, np.ones(ytest_pred.shape[0]))[:,0]

    cv_output = simpls_module.simpls_univariate_cv(xx=Xcal, yy=Ycal, total_nlv=total_lv, number_splits=cv_splits, number_repetitions = cv_reps)

    rmsecv = cv_output[0]
    r2cv = cv_output[1]

    output = {'rmsecv':rmsecv,
              'rmseval': rmseval,
              'rmsep':rmsep,
             'r2cv':r2cv,
             'r2val':r2val,
             'r2p':r2p}

    return output
    


In [28]:
# --- get design and selected samples

# Load the design table stored for this experiment and case. Each row is read
# later as one run and must provide a "selected_samples" entry.
# NOTE(review): pd.read_pickle unpickles arbitrary objects -- only load files
# that were produced by a trusted step of this project.
design_df = pd.read_pickle(base_dir + "/experiments/" + experimentID + "/output/" + caseID_key + "_01_design_selected_samples.pkl")


In [29]:
# --- run all pls models
# For every response column, evaluate one PLS model per row of design_df and
# store the resulting performance table as a pickle in the experiment's output
# folder.

# Settings shared by every model evaluation in this cell.
PLS_TOTAL_LV = 15     # passed as total_lv to model_performance_cv_val_test
CV_REPETITIONS = 1    # passed as cv_reps to model_performance_cv_val_test

output_dir = base_dir + "/experiments/" + experimentID + "/output/"

# `datetime` is the module imported in the first cell; it is deliberately not
# re-imported here as the class, which would rebind that name for the rest of
# the notebook.
print('start: ',datetime.datetime.now())


for jj in range(data_class.get_cal()["ycal"].shape[1]):

    print(data_class.y_names[jj])

    # Column jj is sliced as a 2-D, C-contiguous array.
    xx = data_class.get_cal()["xcal"]
    yy = np.ascontiguousarray(data_class.get_cal()["ycal"][:,jj:(jj+1)])
    xx_test = data_class.get_test()["xtest"]
    yy_test = np.ascontiguousarray(data_class.get_test()["ytest"][:,jj:(jj+1)])

    pls_performance = {}

    for ii in range(design_df.shape[0]):

        # Progress marker every 100 runs.
        if ii%100==0:
            print(ii)

        # Design settings of this run, extended with its performance metrics.
        current_run = dict(design_df.iloc[ii])
        selected_samples = current_run["selected_samples"]

        current_performance = model_performance_cv_val_test(xx,yy,selected_samples,xx_test,yy_test,total_lv=PLS_TOTAL_LV, cv_reps = CV_REPETITIONS)

        current_run.update(current_performance)

        # Keys stay strings so the saved table keeps its string index.
        pls_performance[str(ii)] = current_run

    print('finish: ',datetime.datetime.now())


    # ¡¡¡ --- !!! ---> save output 

    df_output = pd.DataFrame.from_dict(pls_performance, orient="index")
    df_output.to_pickle(output_dir + caseID_key + "_" + data_class.y_names[jj] + "_numba_02_pls_performance.pkl")

start:  2020-09-04 17:23:05.091489
moisture
0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
finish:  2020-09-04 17:24:03.782008


# Model optimizer table with exhaustive sample selection

In [6]:
# Show the model performance table.
# NOTE(review): the stored output of this cell (execution count 6, an
# "Unnamed: 0" column, lexicographically ordered index) may come from a
# different session or object than the df_output built in the loop above --
# re-run on a fresh kernel to confirm.
df_output

Unnamed: 0,npc,method_name,sample_size,total_lv,selected_samples,rmsecv,rmseval,rmsep,r2cv,r2val,r2p
0,1,random,20,15,"[0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, ...","[[[0.32680534016169893]], [[0.3071336095395987...","[0.34694136097003186, 0.27624429038729753, 0.2...","[0.29237126255074897, 0.225305069011708, 0.180...","[[[0.013959605098109673]], [[0.129094355184945...","[0.2846100311028803, 0.5464582080676124, 0.684...","[0.35456759876355115, 0.6167136766450093, 0.75..."
1,1,random,21,15,"[1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, ...","[[[0.3613428826474382]], [[0.3312580218410601]...","[0.32462351023335523, 0.25472007735901825, 0.1...","[0.26503399689104995, 0.21715678044121361, 0.1...","[[[0.12628131462634118]], [[0.2657137056331015...","[0.2459870011242179, 0.5357571221570994, 0.726...","[0.46962313812019496, 0.6439359022328821, 0.81..."
10,4,random,54,15,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, ...","[[[0.3346977510946659]], [[0.2840946106587239]...","[0.2161736677622658, 0.038142529824756834, 0.0...","[0.24956783471730795, 0.21120245276402252, 0.1...","[[[0.21833371498077647]], [[0.4368270562230201...","[-4.52115484803767, 0.8281128802655392, 0.2309...","[0.5297177154053496, 0.6631943934858446, 0.842..."
100,4,puchwein,30,15,"[0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, ...","[[[0.3380367294169447]], [[0.28206637906481546...","[0.33940510596447265, 0.2854998122731694, 0.17...","[0.269593584992753, 0.21645456568591925, 0.147...","[[[0.20500135746760828]], [[0.4464698876746898...","[0.2206879606460712, 0.4485751971196129, 0.800...","[0.4512171825427408, 0.6462349711279095, 0.836..."
1000,7,duplex,29,15,"[0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, ...","[[[0.335780646148973]], [[0.2906535125290237]]...","[0.322900253365154, 0.27210285754139557, 0.148...","[0.2511505747515998, 0.21816468676005044, 0.13...","[[[0.1835082538231696]], [[0.3882251357075366]...","[0.3295071997105611, 0.5238721859578266, 0.857...","[0.5237338124195039, 0.6406229768879854, 0.859..."
1001,7,duplex,30,15,"[0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, ...","[[[0.33705383649707094]], [[0.2802498317103369...","[0.32862703123775583, 0.2757014925743955, 0.14...","[0.2525000172693818, 0.218544578875561, 0.1356...","[[[0.1502679309062801]], [[0.41254557399579816...","[0.3286235197667241, 0.5274608557597049, 0.861...","[0.5186020666183313, 0.6393703143671343, 0.861..."
1002,7,duplex,31,15,"[0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, ...","[[[0.3209745490953467]], [[0.27291637717219946...","[0.33479217683060203, 0.2801361469736111, 0.14...","[0.2524387274474547, 0.21886152366919245, 0.13...","[[[0.20387083505106374]], [[0.4244255141088421...","[0.32892271699391096, 0.5301489116143671, 0.87...","[0.5188357395745078, 0.6383235477661491, 0.865..."
1003,7,duplex,32,15,"[0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, ...","[[[0.3288472665422168]], [[0.2742222827749568]...","[0.34120574280152965, 0.2862787054152567, 0.14...","[0.2550552512862242, 0.21951327638403645, 0.13...","[[[0.13749103057227297]], [[0.4002356727625992...","[0.32979126523476765, 0.5282026375796491, 0.87...","[0.5088095251827011, 0.6361662513256137, 0.870..."
1004,7,duplex,33,15,"[0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, ...","[[[0.31899940269434396]], [[0.2963685175394461...","[0.3475412695803789, 0.2932045126811001, 0.148...","[0.2530147005777024, 0.2202278399571709, 0.127...","[[[0.18825842209866583]], [[0.2993482885366328...","[0.31846943090620616, 0.5149194356893776, 0.87...","[0.5166375519064024, 0.633793679630871, 0.8763..."
1005,1,simplisma,43,15,"[1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, ...","[[[0.3074635835530016]], [[0.2639122194868142]...","[0.39434585180051374, 0.3125834600778755, 0.18...","[0.26011462823507725, 0.2148971562443983, 0.14...","[[[0.2613893618374048]], [[0.4558142222602922]...","[0.1681354492225713, 0.4773270163981661, 0.808...","[0.4891293528035856, 0.6513073973207526, 0.843..."
