# Train ML

> A collection of machine learning tools

In [None]:
#| default_exp train

In [None]:
#| hide
import sys
sys.path.append("/notebooks/katlas")
from nbdev.showdoc import *
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from sklearn import set_config
from katlas.core import Data
from katlas.feature import *
from fastbook import *
import xgboost as xgb
import matplotlib.pyplot as plt
from scipy.stats import spearmanr,pearsonr
from sklearn.model_selection import *
from pathlib import Path
from sklearn.metrics import mean_squared_error
import math
from scipy.stats import spearmanr, pearsonr
from joblib import dump, load
from sklearn.linear_model import *
from sklearn.svm import *
from sklearn.ensemble import *

  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


In [None]:
#| export
from sklearn import set_config
# Configure scikit-learn globally so that transformers return pandas output from
# `transform` / `fit_transform`.
set_config(transform_output="pandas")

## Splitter

In [None]:
#| export
def get_splits(df, # df contains info for split
               stratified=None, # colname to make stratified kfold; sampling from different groups
               group=None, # colname to make group kfold; test and train are from different groups
               nfold=5, # number of folds
               seed=123, # random state of the shuffling splitters (not used by GroupKFold)
              ):
    """Build `nfold` cross-validation splits of `df` and print a summary of the last fold.

    The splitter is chosen from the arguments: `StratifiedKFold` when only `stratified` is given,
    `GroupKFold` when only `group` is given, and `StratifiedGroupKFold` when both are given.
    Returns a list of `(train_idx, test_idx)` pairs. The indices are row *positions* in `df`
    (not index labels). Raises `ValueError` when neither `stratified` nor `group` is provided.
    """
    if stratified is None and group is None:
        raise ValueError("Either 'stratified' or 'group' argument must be provided.")

    # Choose the splitter and the column whose distinct values are reported below.
    if group is None:
        kf = StratifiedKFold(nfold, shuffle=True, random_state=seed)
        fold_iter = kf.split(df.index, df[stratified])
        report_col = stratified
    elif stratified is None:
        kf = GroupKFold(nfold)
        fold_iter = kf.split(df.index, groups=df[group])
        report_col = group
    else:
        kf = StratifiedGroupKFold(nfold, shuffle=True, random_state=seed)
        fold_iter = kf.split(df.index, groups=df[group], y=df[stratified])
        report_col = stratified

    splits = list(fold_iter)
    print(kf)

    # The summary describes the last fold only.
    # Splitter indices are positional, so rows are selected with `.iloc`; `.loc` would only be
    # correct when `df` happens to carry a default 0..n-1 index.
    train_idx, test_idx = splits[-1]
    train_df, test_df = df.iloc[train_idx], df.iloc[test_idx]
    print(f'# kinase {report_col} in train set: {train_df[report_col].unique().shape[0]}')
    print(f'# kinase {report_col} in test set: {test_df[report_col].unique().shape[0]}')

    # The kinase report needs a `kinase` column; skip it for frames that do not have one
    # instead of failing after the splits have already been computed.
    if 'kinase' in df.columns:
        test_kinases = test_df['kinase'].unique()
        print('---------------------------')
        print(f"# kinase in train set: {train_df['kinase'].unique().shape[0]}")

        print('---------------------------')
        print(f'# kinase in test set: {test_kinases.shape[0]}')
        print('---------------------------')
        print(f'test set: {test_kinases}')

    return splits

In [None]:
# Training table: one row per kinase, read from the parquet file.
df = pd.read_parquet('train/scaled_t5.parquet')

# Kinase annotations restricted to `in_paper == 1`, left-joined onto the `kinase` column of `df`
# so that `info` can be used to compute splits for `df`.
info = df[['kinase']].merge(
    Data.get_kinase_info_full().query('in_paper == 1'),
    'left',
)

In [None]:
df.head(2)

Unnamed: 0,kinase,-1A,-1C,-1D,-1E,-1F,-1G,-1H,-1I,-1K,-1L,-1M,-1N,-1P,-1Q,-1R,-1S,-1T,-1V,-1W,-1Y,-1t,-1y,-2A,-2C,-2D,-2E,-2F,-2G,-2H,-2I,-2K,-2L,-2M,-2N,-2P,-2Q,-2R,-2S,-2T,-2V,-2W,-2Y,-2t,-2y,-3A,-3C,-3D,-3E,-3F,-3G,-3H,-3I,-3K,-3L,-3M,-3N,-3P,-3Q,-3R,-3S,-3T,-3V,-3W,-3Y,-3t,-3y,-4A,-4C,-4D,-4E,-4F,-4G,-4H,-4I,-4K,-4L,-4M,-4N,-4P,-4Q,-4R,-4S,-4T,-4V,-4W,-4Y,-4t,-4y,-5A,-5C,-5D,-5E,-5F,-5G,-5H,-5I,-5K,-5L,-5M,-5N,-5P,-5Q,-5R,-5S,-5T,-5V,-5W,-5Y,-5t,-5y,1A,1C,1D,1E,1F,1G,1H,1I,1K,1L,1M,1N,1P,1Q,1R,1S,1T,1V,1W,1Y,1t,1y,2A,2C,2D,2E,2F,2G,2H,2I,2K,2L,2M,2N,2P,2Q,2R,2S,2T,2V,2W,2Y,2t,2y,3A,3C,3D,3E,3F,3G,3H,3I,3K,3L,3M,3N,3P,3Q,3R,3S,3T,3V,3W,3Y,3t,3y,4A,4C,4D,4E,4F,4G,4H,4I,4K,4L,4M,4N,4P,4Q,4R,4S,4T,4V,4W,4Y,4t,4y,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,262,263,264,265,266,267,268,269,270,271,272,273,274,275,276,277,278,279,280,281,282,283,284,285,286,287,288,289,290,291,292,293,294,295,296,297,298,299,...,525,526,527,528,529,530,531,532,533,534,535,536,537,538,539,540,541,542,543,544,545,546,547,548,549,550,551,552,553,554,555,556,557,558,559,560,561,562,563,564,565,566,567,568,569,570,
571,572,573,574,575,576,577,578,579,580,581,582,583,584,585,586,587,588,589,590,591,592,593,594,595,596,597,598,599,600,601,602,603,604,605,606,607,608,609,610,611,612,613,614,615,616,617,618,619,620,621,622,623,624,625,626,627,628,629,630,631,632,633,634,635,636,637,638,639,640,641,642,643,644,645,646,647,648,649,650,651,652,653,654,655,656,657,658,659,660,661,662,663,664,665,666,667,668,669,670,671,672,673,674,675,676,677,678,679,680,681,682,683,684,685,686,687,688,689,690,691,692,693,694,695,696,697,698,699,700,701,702,703,704,705,706,707,708,709,710,711,712,713,714,715,716,717,718,719,720,721,722,723,724,725,726,727,728,729,730,731,732,733,734,735,736,737,738,739,740,741,742,743,744,745,746,747,748,749,750,751,752,753,754,755,756,757,758,759,760,761,762,763,764,765,766,767,768,769,770,771,772,773,774,775,776,777,778,779,780,781,782,783,784,785,786,787,788,789,790,791,792,793,794,795,796,797,798,799,800,801,802,803,804,805,806,807,808,809,810,811,812,813,814,815,816,817,818,819,820,821,822,823,824,825,826,827,828,829,830,831,832,833,834,835,836,837,838,839,840,841,842,843,844,845,846,847,848,849,850,851,852,853,854,855,856,857,858,859,860,861,862,863,864,865,866,867,868,869,870,871,872,873,874,875,876,877,878,879,880,881,882,883,884,885,886,887,888,889,890,891,892,893,894,895,896,897,898,899,900,901,902,903,904,905,906,907,908,909,910,911,912,913,914,915,916,917,918,919,920,921,922,923,924,925,926,927,928,929,930,931,932,933,934,935,936,937,938,939,940,941,942,943,944,945,946,947,948,949,950,951,952,953,954,955,956,957,958,959,960,961,962,963,964,965,966,967,968,969,970,971,972,973,974,975,976,977,978,979,980,981,982,983,984,985,986,987,988,989,990,991,992,993,994,995,996,997,998,999,1000,1001,1002,1003,1004,1005,1006,1007,1008,1009,1010,1011,1012,1013,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023
0,AAK1,0.946369,0.750092,0.396778,0.359132,1.338569,1.735594,1.691909,1.068228,3.021196,1.607434,1.36108,1.065893,2.791637,0.9763,2.269404,0.646482,0.773602,0.962921,1.08113,1.90951,0.297487,0.438644,1.15445,1.228171,1.7257,1.364512,0.720698,0.605965,1.215434,0.810457,1.259543,1.221001,2.501417,1.733176,0.485449,3.253871,1.189908,1.252112,1.849408,0.729657,0.620334,0.855378,0.449146,1.042213,2.015915,1.149534,0.489356,0.667119,0.877569,0.922324,0.716192,1.079547,0.931704,0.950197,1.409066,1.048676,1.954319,1.252821,1.04648,1.476805,1.492007,1.154782,0.765315,0.741093,0.545392,0.493679,1.524235,1.176241,0.716604,1.207458,1.121429,1.385753,1.109553,1.34073,1.745601,1.600687,1.496126,0.926268,1.152337,1.352589,1.542429,1.251705,1.448485,1.334892,0.870723,1.152977,0.715856,0.731355,0.657475,1.028856,0.371345,0.35441,0.984259,0.567495,0.767758,3.599866,0.60621,2.300098,2.000824,0.637286,1.667971,1.297654,2.213788,1.083113,1.821688,2.203011,0.72892,2.204366,0.466312,1.415361,0.945665,0.658737,0.311431,0.297835,0.39064,21.898553,0.485795,0.309391,0.588938,0.396854,0.374869,0.516644,1.405348,0.418832,0.755945,0.601498,1.951873,0.325021,0.492216,0.433639,0.373723,0.304418,1.595721,1.418367,1.175191,1.114377,1.033853,1.594455,1.318701,0.787707,1.227595,0.880516,1.02346,1.774478,1.121415,1.641377,1.647485,1.802696,6.862167,1.154721,1.218073,0.97679,0.871068,0.872911,1.473434,1.828538,1.024875,0.940439,1.08807,2.336723,1.705319,0.9815,1.870017,1.236928,1.106091,1.86118,1.724726,1.577127,2.280772,1.856363,7.939573,1.544267,1.216757,1.348203,0.848044,0.907864,1.521485,1.383726,0.917663,1.078137,1.233385,1.653679,1.318901,0.978864,1.957998,1.086193,1.093726,1.394213,1.479688,1.496959,2.186988,1.282003,2.625764,0.995587,1.947299,1.227889,0.592576,0.637314,0.075806,0.123474,0.04422,0.024155,-0.020981,0.018646,-0.051666,-0.073914,-0.000199,-0.029678,0.01133,-0.004208,-0.042145,0.049408,0.072815,0.02742,0.046356,-0.019913,-0.00901,-0.006969,-0.075195,0.005646,0.066467,-0.017838,-0.037
292,-0.032257,0.033264,-0.050903,-0.030777,-0.001095,-0.010994,0.090637,-0.113647,-0.023529,-0.119446,0.000334,0.041321,0.052979,-0.010551,0.049133,-0.024902,-0.056641,-0.027313,-0.003847,-0.008194,-0.022583,-0.014557,-0.013809,-0.039062,-0.016296,-0.016205,-0.004978,-0.030731,-0.064026,0.076538,0.010185,-0.002413,-0.044708,-0.038177,0.012962,-0.03949,0.033417,-0.023499,-0.050323,-0.005211,-0.008736,-0.013977,-0.00338,0.013565,-0.028488,-0.001527,-0.019012,-0.112305,-0.09668,-0.009743,0.071838,0.090027,-0.016739,0.007217,0.047211,0.008972,-6.5e-05,0.000295,0.037201,-0.016739,-0.041321,0.048187,-0.033905,-0.012993,0.029816,0.044495,0.012627,-0.021835,-0.022415,0.004944,-0.015312,0.036835,-0.001993,-0.017105,-0.024994,-0.052826,-0.011192,0.009392,0.037384,0.042358,0.059448,-0.062073,-0.028381,-0.026581,-0.100098,-0.006435,0.008713,-0.029282,-0.01371,0.010635,0.031494,0.017883,-0.02327,-0.017365,-0.00988,-0.030853,-0.003613,0.005623,-0.008797,-0.028259,0.038147,-0.006405,0.002558,0.037567,0.014839,0.064209,0.013557,-0.01371,-0.054413,0.007378,-0.037842,0.011246,-0.008766,-0.006462,-0.011086,0.08136,-0.042206,0.000128,0.012131,0.011658,-0.009361,0.045837,-0.016434,-0.034546,-0.026688,0.024216,-0.047302,0.002134,-0.007587,-0.019867,-0.022842,0.036072,0.037354,0.00321,-0.015015,-0.000678,-0.01889,-0.006477,-0.022598,0.005695,0.045135,-0.01413,-0.05188,-0.030502,0.015106,0.000329,-0.043365,0.024353,0.030045,-0.001305,0.004642,-0.005768,-0.017014,-0.004181,-0.049713,-0.056427,0.009155,-0.045776,0.007156,-0.01828,-0.0047,0.001636,0.020081,-0.088074,-0.04895,0.0336,-0.02507,-0.00634,0.000706,0.01973,0.04007,0.007988,0.007843,0.018036,0.006222,-0.080261,0.015335,-0.063904,-0.111755,0.002302,-0.04541,0.139771,0.020996,-0.023407,0.030762,-0.000692,0.022141,0.019226,0.014305,-0.021378,0.020935,0.033112,-0.078491,0.028824,0.060852,0.015511,-0.068359,0.023346,0.006447,0.010971,0.021866,-0.052643,0.034119,0.002172,-0.039032,0.007835,0.076233,0.016296,0.003429,-0.03717,-0.004272,-0.0
16205,-0.022598,-0.032166,0.020111,-0.033417,0.077026,-0.049713,-0.042755,-0.027512,0.031677,0.044647,0.012749,0.001925,-0.03186,0.002102,-0.011528,-0.016998,0.013962,-0.003967,0.004162,0.092957,-0.018204,0.009781,0.00465,-0.070496,-0.024796,0.009285,0.027481,0.002068,-0.005741,0.028275,0.016403,0.029007,-0.042999,-0.013016,-0.033539,-0.022537,-0.120422,-0.008606,-0.00031,0.058533,-0.079346,-0.019516,-0.110107,0.010979,0.061066,-0.001573,-0.020798,-0.050171,-0.000573,0.013115,0.06604,0.020599,0.054962,-0.008514,0.065063,0.04184,-0.017944,0.134644,-0.004711,-0.033722,-0.008011,0.007278,0.0401,...,0.034637,0.012474,-0.01429,0.003443,0.031586,-0.022568,0.012405,0.054504,-0.001762,0.015305,-0.039459,-0.064819,-0.148438,-0.014084,0.042419,-0.022079,-0.013077,0.038849,-0.031586,-0.004677,-0.036346,0.010551,-0.024612,-0.018005,-0.01886,-0.029633,0.014496,-0.003729,0.017487,0.011093,-0.001432,-0.015602,-0.018524,0.010529,-0.027832,-0.049225,-0.009209,-3.1e-05,-0.028702,-0.015594,0.033752,0.015961,-0.005936,-0.02121,0.029053,0.021011,0.020279,0.064697,0.012329,-0.018188,0.036713,-0.005318,-0.007664,0.053284,-0.003691,0.005142,-0.006172,-0.019073,0.005718,0.003899,0.022537,-0.002831,0.034668,0.052643,-0.023148,-0.000988,-0.012932,-0.009491,-0.032379,0.028214,-0.016068,0.015404,0.013412,0.002834,-0.017426,0.043701,0.031464,-0.025864,0.111267,-0.014275,-0.001395,0.0093,-0.066467,-0.036255,-0.005653,0.015358,0.033234,-0.037903,-0.011765,0.109436,0.137451,-0.004978,-0.013557,-0.040253,-0.02092,0.005733,0.002939,-0.06076,0.008247,0.055267,-0.059875,-0.000307,-0.027863,-0.001251,0.016037,0.087402,0.058044,-0.016617,0.005623,0.029648,0.037872,0.006912,-0.009247,-0.01017,0.027267,-0.067322,-0.011871,-0.032593,-0.023834,-0.08905,-0.012207,-0.018814,0.010475,-0.074951,-0.00565,0.030731,-0.022736,-0.029816,-0.021774,0.022705,-0.016571,0.030106,0.032043,0.012131,0.020142,0.012009,-0.031921,0.028488,-0.015404,0.018997,-0.010689,-0.06192,0.027191,0.025986,0.064026,0.042206,0.032043,0.02247
6,-0.006386,0.033905,-0.069275,-0.056213,-0.039581,-0.072876,-0.024063,0.001897,-0.026001,0.009972,0.024414,-0.019775,0.008499,0.020157,-0.026718,-0.008011,-0.005413,0.025024,0.009827,0.002968,-0.025772,-0.012451,0.000473,0.015701,-0.031464,-0.065674,0.003244,0.007618,0.003496,0.01091,0.03183,0.020996,-0.001505,0.018921,0.03183,-0.02655,-0.017853,0.001337,0.038391,0.070496,0.009651,0.024048,0.008141,0.015228,-0.011986,-0.038666,-0.012329,-0.000697,-0.103027,0.023438,-0.015335,0.026443,-0.080933,0.009361,0.001467,0.001709,0.017639,-0.01651,0.01255,-0.010788,0.028931,-0.058014,0.004837,0.009674,-0.021881,-0.001799,0.034241,0.015686,0.036102,0.005127,-0.022034,0.009987,-0.046265,-0.024734,0.004166,0.027405,-0.0513,0.01088,0.06781,-0.002954,-0.050903,0.008492,0.038055,-0.037231,-0.018356,-0.043121,-0.030258,-0.029678,-0.065063,-0.005772,-0.042999,0.039581,0.014565,0.021393,-0.007717,0.016815,0.033142,0.005249,0.0448,-0.020416,-0.008759,0.063965,0.018723,-0.045959,0.029892,0.01709,0.07605,-0.012802,0.006939,-0.044861,0.035339,0.032104,0.027878,0.031281,-0.026672,0.025879,0.027588,0.064209,-0.016693,-0.106262,0.043396,-0.020554,0.00396,0.039856,0.019638,-0.027756,-0.052643,0.061096,-0.003496,-0.005116,0.015068,0.008629,-0.023254,-0.025604,0.000993,-0.009666,0.013908,0.014221,0.036591,0.027252,-0.019653,0.132812,-0.024612,0.00721,0.022903,0.021088,0.004574,-0.040985,0.041565,0.057495,-0.014626,0.001556,0.007046,0.033417,0.034332,-0.010231,0.003181,0.028168,-0.017349,0.035614,-0.031555,-0.007126,-0.028915,0.005169,0.098511,-0.018799,0.019226,0.003691,0.002443,0.037903,0.003136,-0.010521,0.019135,0.009346,0.000595,0.001385,0.004177,0.009224,-0.025925,-0.017624,-0.019257,-0.032532,0.022842,-0.043121,-0.010902,0.062988,-0.039948,-0.012985,0.026993,0.042328,-0.019272,-0.000315,-0.032928,0.012604,-0.000112,-0.022339,-0.052765,0.010628,0.010719,-0.010895,0.029648,-0.022995,-0.056702,0.018051,0.009293,0.029434,-0.002823,-0.012001,-0.041901,-0.027863,0.025955,0.022141,0.014969,-0.0
29648,-0.006279,0.021393,0.0336,0.017487,0.015182,0.004604,-0.006794,0.020569,-0.025345,0.013672,0.007732,0.113403,-0.00391,-0.049805,0.053711,-0.037842,-0.01152,-0.054077,-0.0224,0.009041,0.016083,0.016602,-0.037476,-0.012039,-0.022522,0.032288,0.080566,0.00666,0.02298,-0.040619,-0.039673,-0.002926,-0.019974,-0.01107,-0.003387,-0.000525,0.005066,-0.023087,-0.030716,-0.00158,-0.019974,0.039185,0.035187,0.032806,-0.035583,-0.017059,0.000515,-0.033417,0.035339,-0.005024,-0.009666,0.037506,-0.016525,-0.023376,-0.048279,-0.010506,0.037354,-0.009666,-0.024734,0.009277,-0.10791,-0.03656,-0.085999,-0.060303,-0.027557,-0.007263,-0.15918,0.000834,-0.030472,-0.020294,0.010719,0.023392,0.000917,0.014282,0.05835,0.008934,0.027847,-0.002787,-0.062073,0.074097,0.027603,0.004349,-0.022339,0.020584,-0.03717,-0.010078,-0.020508,0.026749,0.033417,-0.036163,0.024567,-0.088989,-0.026703,0.007843,0.022552,-0.023438,-0.010834,0.044189,-0.021942,-0.064331,0.005058,-0.006531,0.065491,0.024567,-0.041321,0.031006,0.023758,-0.022827,-0.011292,-0.045258,-0.048096,-0.013275,-0.014236,-0.012283,-0.020538,0.034363,0.018646,0.130005,-0.04776,0.045197,-0.016388,0.002653,-0.037872,-0.004353,-0.02298,-0.006428,0.00872,-0.042908,-0.033295,-0.037445,0.004528,0.08075,-0.029404,-0.041809,-0.010941,-0.026306,-0.013878
1,ACVR2A,0.817853,2.507437,1.223991,1.287046,1.16328,0.567058,1.064461,0.800036,0.529279,1.26759,1.432364,0.811907,0.672279,0.894928,0.574185,1.33189,1.403282,1.098782,1.38008,1.3243,1.365207,1.801629,0.666375,2.20696,5.403941,6.571803,0.633808,0.628259,0.612366,0.570055,0.445351,0.541286,0.500831,0.849266,0.403309,0.965851,0.633851,1.603416,1.383578,0.631034,0.678183,0.625199,1.270243,0.9056,0.967938,1.510075,1.894097,2.172636,0.930175,1.177731,0.954922,0.906949,0.639496,0.8637,0.928723,0.905888,0.905848,0.916069,0.83071,1.500418,1.552451,0.984317,1.06536,0.991942,2.173145,1.758313,1.04447,1.477509,1.958134,1.636104,1.165014,1.033113,1.037924,0.990884,0.764216,0.978966,1.085357,0.943518,0.876613,0.907915,0.860701,1.45613,1.513285,1.023313,1.419717,1.102758,1.522305,1.282888,1.185107,1.413459,1.803529,1.603117,1.219221,0.977011,1.157957,1.269254,1.036371,1.209452,1.058476,1.088901,0.843088,0.874055,0.96537,1.194614,1.284031,1.214249,1.631005,1.173103,1.59081,1.544484,0.630237,1.76377,1.639788,3.125961,0.907378,0.560338,0.68183,1.144896,0.318509,0.850475,0.969826,0.621665,0.331491,1.197223,0.405319,1.166477,1.310639,1.290526,0.936397,0.945533,4.528119,2.73444,1.123892,1.454261,1.240273,1.028571,0.98542,1.360871,1.12465,0.916751,0.806405,0.675687,0.880837,0.864695,0.990975,1.007216,0.937349,1.165704,1.027286,1.001489,0.984974,1.108442,1.106946,0.879231,0.903146,1.168926,1.061076,1.485729,1.099421,1.010464,1.006445,0.946048,0.788339,1.045034,1.005513,0.894916,1.142814,0.851548,0.660316,1.34452,1.221341,0.967054,1.084382,1.117174,0.885409,1.391548,0.887266,1.310805,1.139014,1.140323,0.932088,0.967855,1.02019,1.002589,0.938573,0.911637,1.178528,0.990293,1.350281,1.09631,0.874794,0.987733,1.154086,0.918866,1.425005,1.031622,1.252413,1.048924,0.029007,0.103027,0.061066,0.005985,-0.013107,0.037109,-0.012222,-0.060669,0.004623,-0.005775,0.030151,0.015556,-0.030258,0.020645,0.089172,0.001034,0.022202,-0.037781,-0.020264,-0.018387,-0.075317,0.010017,0.061615,0.009468,-0.007374
,-0.026215,0.011856,-0.051971,-0.011002,0.011406,-0.017639,0.062866,-0.112061,-0.022629,-0.121094,0.013161,0.045197,0.055756,-0.016953,0.063843,-0.036804,-0.033844,-0.048645,0.01226,-0.008698,-0.011284,0.009644,0.007507,-0.014481,-0.011894,-0.016037,0.018707,-0.038605,-0.044189,0.062927,0.005104,-0.012032,-0.051758,0.010742,0.010849,-0.052826,0.024704,-0.030273,-0.038025,-0.025436,-0.001678,0.018768,-0.008682,0.041565,-0.04538,0.013229,-0.012329,-0.120361,-0.066589,-0.008636,0.063599,0.03717,-0.042847,0.029266,0.03125,0.016785,0.005226,0.002333,0.060181,-0.017105,-0.023331,0.049042,-0.037231,0.007866,0.040131,0.047455,0.004395,-0.009598,-0.001114,-0.007904,-0.037231,0.008499,-0.000888,-0.030136,-0.037476,-0.052856,-0.000756,-0.005741,0.045807,0.041565,0.051422,-0.04892,-0.015472,-0.025818,-0.051117,0.014206,-0.001417,-0.05011,-0.018021,0.013222,0.046265,0.004776,-0.025894,-0.063599,-0.02739,-0.064697,0.007751,-0.008469,-0.02655,-0.041595,0.015083,-0.016663,0.011642,0.050415,0.008987,0.070435,0.001126,-0.010536,-0.060394,-0.019119,-0.058105,0.039398,-0.027863,-0.015793,-1e-05,0.106628,-0.041626,-0.001277,0.021667,0.0149,0.007595,0.050415,-0.018692,0.006332,-0.008736,0.006306,-0.040161,0.005966,-0.022339,-0.018936,-0.03183,0.059845,0.055511,-0.024002,-0.012108,-0.010918,-0.012375,-0.011925,-0.020248,-0.005135,0.054901,0.029831,-0.007904,-0.011169,-0.011452,-0.022263,-0.034912,0.030289,0.06192,-0.019638,0.017365,-0.017258,-0.040222,-0.01519,-0.035492,-0.04361,-0.011116,-0.035309,-0.003164,-0.012306,0.011658,0.002655,0.026154,-0.077576,-0.036346,0.024033,-0.01265,-0.028931,-0.001731,0.023331,0.023087,4.3e-05,0.028259,-0.01384,-0.013672,-0.050201,-0.046631,-0.049774,-0.080811,0.008888,-0.031891,0.156128,0.032654,-0.009933,0.045349,-0.001379,0.035522,0.000117,0.009468,-0.018143,0.008736,0.032471,-0.084656,0.015961,0.044067,0.005638,-0.051819,-0.020645,0.013084,0.039124,0.017761,-0.053101,0.039307,-0.023193,-0.018478,0.030579,0.047272,-0.017441,0.023651,-0.041199,-0.016815
,-0.046265,-0.021515,-0.033813,0.012535,-0.073425,0.052338,-0.049255,-0.04599,-0.037445,0.045807,0.027542,0.034821,0.006687,-0.022446,0.024567,0.010307,-0.011086,-0.001564,0.033112,0.021667,0.094055,-0.004128,-0.009567,-0.024109,-0.089722,-0.041046,-0.016479,0.00655,0.028366,0.002405,0.02182,0.040436,0.007057,-0.046631,0.001845,-0.033386,0.005833,-0.120422,-0.007431,0.008469,0.054413,-0.081726,-0.0336,-0.124817,-0.012527,0.041992,0.008461,-0.018738,-0.04776,-0.01918,-0.020584,0.042175,0.002954,0.060883,-0.034058,0.050049,0.012253,-0.012169,0.131714,-0.024689,-0.023529,-0.040985,0.022903,0.054626,...,0.017517,0.003677,-0.027863,0.008614,0.041473,-0.009636,0.0159,0.058533,-0.024796,-0.005993,-0.029587,-0.03244,-0.114685,0.001323,-0.005829,0.008759,-0.003279,0.052002,-0.031494,0.017227,-0.0336,0.021637,-0.01133,-0.016342,-0.010193,-0.030258,-0.005199,0.007851,0.014267,0.014771,0.001123,-0.039062,0.004692,-0.002415,-0.006908,0.004269,0.003597,-0.021301,-0.029556,-0.030838,0.009674,0.022308,0.003401,-0.028381,0.033264,0.046692,0.001947,0.049377,-0.013664,-0.035889,0.052551,-0.00853,-0.013184,0.050507,-0.013542,-0.018219,-0.009415,-0.018768,-0.004803,0.000857,0.02478,0.016327,0.040649,0.043915,-0.006977,-0.005981,-0.033295,-0.007645,-0.008698,0.046722,-0.004078,0.01004,0.014282,0.015091,-0.02272,0.056549,-0.009117,-0.02037,0.071716,-0.00737,0.020721,-0.000519,-0.042542,-0.028442,0.008171,0.028061,0.015556,-0.038086,-0.006931,0.093872,0.140015,-0.001546,-0.015869,-0.013054,-0.018036,0.029678,0.016281,-0.045166,0.022568,0.033051,-0.040619,0.018173,-0.017242,-0.001512,0.024475,0.077026,0.0242,0.021362,-0.007648,0.025787,0.058136,0.049561,0.006523,-0.001993,0.022766,-0.042786,-0.013763,-0.027252,-0.014206,-0.064758,0.006783,0.021469,0.023071,-0.063171,0.003593,0.0168,-0.016647,-0.046387,-0.033203,-0.009918,-0.011314,0.024658,0.030792,-0.00695,-0.005722,-0.010719,-0.011566,0.041656,-0.057617,-0.000494,0.027451,-0.042725,0.025284,0.025696,0.049866,0.074036,0.033264,0.014633,0.0
08896,0.01239,-0.096069,-0.051788,-0.052002,-0.055267,-0.052094,0.003223,-0.0233,0.0103,-0.016083,-0.011383,-0.011497,0.041809,-0.008011,-0.032166,0.004097,0.010521,0.028412,-0.00415,-0.012848,-0.008392,-0.020096,-0.000323,-0.029404,-0.063904,-0.018417,0.01152,0.000278,0.018524,0.059235,0.034302,-5.7e-05,0.019424,0.050018,-0.02771,0.001359,-0.013458,0.05423,0.060486,-0.001789,0.001125,-0.028793,-0.018051,-0.045868,0.002676,0.018509,-0.026108,-0.089355,0.025543,-0.005939,0.013565,-0.077209,-0.020737,0.003222,0.003731,0.028091,-0.013428,0.003681,-0.014816,0.003206,-0.058167,-0.032867,0.006481,0.003855,0.011108,-0.010193,0.012085,0.054352,0.007072,-0.018524,0.013954,-0.053497,-0.011314,-0.000908,0.038147,-0.026672,0.023788,0.064087,-0.038361,-0.042999,0.013,0.037781,-0.019974,0.00708,-0.0354,-0.042511,-0.017136,-0.075623,0.001192,-0.049133,0.049286,0.013275,0.010361,0.009712,0.020615,0.03299,-0.057404,0.049377,0.002516,-0.00433,0.078064,-0.010246,-0.048096,0.047607,0.035248,0.100891,-0.009888,0.036713,-0.046173,-0.004818,0.016235,0.048553,0.026718,-0.0383,0.031708,0.027313,0.021179,-0.015823,-0.112244,0.044891,0.002653,0.008331,0.017685,0.007504,-0.033966,-0.049072,0.047882,-0.022293,0.014549,0.030823,-0.008141,-0.028366,-0.021011,0.009583,-0.015762,0.016602,0.026093,0.010475,0.014526,-0.042145,0.128906,-0.028778,0.008614,0.004406,0.022491,0.01104,-0.031311,0.031525,0.050446,-0.008499,-0.010483,0.001303,0.018295,0.044159,0.00531,-0.0112,0.041718,-0.017181,0.032135,-0.035492,-0.021072,-0.039917,0.039032,0.09845,-0.009743,0.013229,0.013329,-0.030533,0.049408,0.012199,-0.009766,0.036163,0.003252,-0.003979,0.011383,0.016357,0.016907,-0.044464,-0.011024,-0.004772,-0.045959,0.01503,-0.037262,-0.053345,0.092346,-0.036346,-0.020126,0.031052,0.047943,-0.040222,0.026901,-0.016922,0.000418,-0.021362,-0.01857,-0.016174,-0.017776,0.013237,-0.017532,0.018051,-0.025192,-0.059509,0.043793,-0.014153,0.035889,-0.003798,-0.029831,-0.010002,-0.05072,0.034058,-0.007217,-0.00824,0.020447,-0
.001945,0.027832,0.034912,0.013718,0.02298,0.009476,-0.012558,0.001751,-0.030533,0.002953,0.025116,0.112671,0.003078,-0.048492,0.045502,-0.050262,0.001666,-0.062225,-0.009262,-0.005882,0.030899,0.043884,-0.011253,-0.027573,-0.032104,0.028931,0.084106,0.02298,0.02829,-0.028061,-0.043518,0.005112,-0.02626,-0.011543,0.00481,-0.029877,-0.015312,-0.031982,-0.018036,-0.004185,-0.021362,0.016541,0.022842,0.044769,-0.005585,-0.003082,-0.005203,-0.028503,0.021194,-0.021393,-0.001122,0.033295,0.003952,-0.012161,-0.041229,0.006432,-0.006126,0.004238,-0.036255,0.01226,-0.111145,0.003555,-0.092346,-0.081665,-0.007359,-0.014404,-0.174438,0.017197,-0.044891,-0.006252,0.018982,0.013046,-0.016479,0.013336,0.04953,-0.00256,0.005241,0.01284,-0.079407,0.064758,0.00679,0.00745,-0.009621,0.026291,-0.052307,0.002789,-0.011368,0.040161,-0.008919,-0.027435,0.000631,-0.081116,-0.007957,0.00407,-0.007835,0.012711,-0.003593,0.022308,-0.006233,-0.044067,0.005409,-0.026657,0.047485,0.031586,-0.012573,0.021088,0.00146,-0.000111,-0.013992,-0.04422,-0.033234,-0.005283,0.008911,0.00544,0.021072,0.025116,0.018661,0.128662,-0.051025,0.024353,0.017044,0.007118,-0.011482,0.019058,-0.028366,-0.011894,0.016907,-0.035828,-0.026596,-0.039948,-0.017471,0.05484,-0.013222,-0.036804,-0.011765,-0.046844,-0.013596


In [None]:
info.head(3)

Unnamed: 0,kinase,ID_coral,uniprot,ID_HGNC,group,family,subfamily_coral,subfamily,in_paper,pseudo,species_paper,uniprot_paper,range,human_uniprot_sequence,full_sequence_paper,sequence,category,category_big,cluster,length,kinasecom_domain,hdbscan
0,AAK1,AAK1,Q2M2I8,AAK1,Other,NAK,,NAK,1,0,human,Q2M2I8,aa27-365,MKKFFDSRREQGGSGLGSGSSGGGGSTSGLGSGYIGRVFGIGRQQVTVDEVLAEGGFAIVFLVRTSNGMKCALKRMFVNNEHDLQVCKREIQIMRDLSGHKNIVGYIDSSINNVSSGDVWEVLILMDFCRGGQVVNLMNQRLQTGFTENEVLQIFCDTCEAVARLHQCKTPIIHRDLKVENILLHDRGHYVLCDFGSATNKFQNPQTEGVNAVEDEIKKYTTLSYRAPEMVNLYSGKIITTKADIWALGCLLYKLCYFTLPFGESQVAICDGNFTIPDNSRYSQDMHCLIRYMLEPDPDKRPDIYQVSYFSFKLLKKECPIPNVQNSPIPAKLPEPVKASEAAAKKTQPKARLTDPIPTTETSIAPRQRPKAGQTQPNPGILPIQPALTPRKRATVQPPPQAAGSSNQPGLLASVPQPKPQAPPSQPLPQTQAKQPQAPPTPQQTPSTQAQGLPAQAQATPQHQQQLFLKQQQQQQQPPPAQQQPAGTFYQQQQAQTQQFQAVHPATQKPAIAQFPVVSQGGSQQQLMQNFYQQQQQQQQQQQQQQLATALHQQQLMTQQAALQQKPTMAAGQQPQPQPAAAPQPAPAQEPAIQAP...,MKKFFDSRREQGGSGLGSGSSGGGGSTSGLGSGYIGRVFGIGRQQVTVDEVLAEGGFAIVFLVRTSNGMKCALKRMFVNNEHDLQVCKREIQIMRDLSGHKNIVGYIDSSINNVSSGDVWEVLILMDFCRGGQVVNLMNQRLQTGFTENEVLQIFCDTCEAVARLHQCKTPIIHRDLKVENILLHDRGHYVLCDFGSATNKFQNPQTEGVNAVEDEIKKYTTLSYRAPEMVNLYSGKIITTKADIWALGCLLYKLCYFTLPFGESQVAICDGNFTIPDNSRYSQDMHCLIRYMLEPDPDKRPDIYQVSYFSFKLLKKECPIPNVQNSPIPAKLPEPVKASEAAAKKTQPKARLTDPIPTTETSIAPRQRPKAGQTQPNPGILPIQPALTPRKRATVQPPPQAAGSSNQPGLLASVPQPKPQAPPSQPLPQTQAKQPQAPPTPQQTPSTQAQGLPAQAQATPQHQQQLFLKQQQQQQQPPPAQQQPAGTFYQQQQAQTQQFQAVHPATQKPAIAQFPVVSQGGSQQQLMQNFYQQQQQQQQQQQQQQLATALHQQQLMTQQAALQQKPTMAAGQQPQPQPAAAPQPAPAQEPAIQAP...,TSGLGSGYIGRVFGIGRQQVTVDEVLAEGGFAIVFLVRTSNGMKCALKRMFVNNEHDLQVCKREIQIMRDLSGHKNIVGYIDSSINNVSSGDVWEVLILMDFCRGGQVVNLMNQRLQTGFTENEVLQIFCDTCEAVARLHQCKTPIIHRDLKVENILLHDRGHYVLCDFGSATNKFQNPQTEGVNAVEDEIKKYTTLSYRAPEMVNLYSGKIITTKADIWALGCLLYKLCYFTLPFGESQVAICDGNFTIPDNSRYSQDMHCLIRYMLEPDPDKRPDIYQVSYFSFKLLKKECPIPNVQNSPIPAKLPEPVKASEAAAKKTQPKARLTDPIPTTETSIA,NAK,NAK,18.0,339,VTVDEVLAEGGFAIVFLVRTSNGMKCALKRMFVNNEHDLQVCKREIQIMRDLSGHKNIVGYIDSSINNVSSGDVWEVLILMDFCRGGQVVNLMNQRLQTGFTENEVLQIFCDTCEAVARLHQCKTPIIHRDLKVENILLHDRGHYVLCDFGSATNKFQNPQTEGVNAVEDEIKKYTTLSYRAPEMVNLYSGKIITTKADIWALGCLLYKLCYFTLPFGESQVAICDGNFTIPDNSRYSQDMHCLIRYMLEPDPDKRPDIYQVSYF,3.0
1,ACVR2A,ACTR2,P27037,ACVR2A,TKL,STKR,STKR2,STKR2,1,0,human,P27037,aa162-end,MGAAAKLAFAVFLISCSSGAILGRSETQECLFFNANWEKDRTNQTGVEPCYGDKDKRRHCFATWKNISGSIEIVKQGCWLDDINCYDRTDCVEKKDSPEVYFCCCEGNMCNEKFSYFPEMEVTQPTSNPVTPKPPYYNILLYSLVPLMLIAGIVICAFWVYRHHKMAYPPVLVPTQDPGPPPPSPLLGLKPLQLLEVKARGRFGCVWKAQLLNEYVAVKIFPIQDKQSWQNEYEVYSLPGMKHENILQFIGAEKRGTSVDVDLWLITAFHEKGSLSDFLKANVVSWNELCHIAETMARGLAYLHEDIPGLKDGHKPAISHRDIKSKNVLLKNNLTACIADFGLALKFEAGKSAGDTHGQVGTRRYMAPEVLEGAINFQRDAFLRIDMYAMGLVLWELASRCTAADGPVDEYMLPFEEEIGQHPSLEDMQEVVVHKKKRPVLRDYWQKHAGMAMLCETIEECWDHDAEARLSAGCVGERITQMQRLTNIITTEDIVTVVTMVTNVDFPPKESSL,MGAAAKLAFAVFLISCSSGAILGRSETQECLFFNANWEKDRTNQTGVEPCYGDKDKRRHCFATWKNISGSIEIVKQGCWLDDINCYDRTDCVEKKDSPEVYFCCCEGNMCNEKFSYFPEMEVTQPTSNPVTPKPPYYNILLYSLVPLMLIAGIVICAFWVYRHHKMAYPPVLVPTQDPGPPPPSPLLGLKPLQLLEVKARGRFGCVWKAQLLNEYVAVKIFPIQDKQSWQNEYEVYSLPGMKHENILQFIGAEKRGTSVDVDLWLITAFHEKGSLSDFLKANVVSWNELCHIAETMARGLAYLHEDIPGLKDGHKPAISHRDIKSKNVLLKNNLTACIADFGLALKFEAGKSAGDTHGQVGTRRYMAPEVLEGAINFQRDAFLRIDMYAMGLVLWELASRCTAADGPVDEYMLPFEEEIGQHPSLEDMQEVVVHKKKRPVLRDYWQKHAGMAMLCETIEECWDHDAEARLSAGCVGERITQMQRLTNIITTEDIVTVVTMVTNVDFPPKESSL,RHHKMAYPPVLVPTQDPGPPPPSPLLGLKPLQLLEVKARGRFGCVWKAQLLNEYVAVKIFPIQDKQSWQNEYEVYSLPGMKHENILQFIGAEKRGTSVDVDLWLITAFHEKGSLSDFLKANVVSWNELCHIAETMARGLAYLHEDIPGLKDGHKPAISHRDIKSKNVLLKNNLTACIADFGLALKFEAGKSAGDTHGQVGTRRYMAPEVLEGAINFQRDAFLRIDMYAMGLVLWELASRCTAADGPVDEYMLPFEEEIGQHPSLEDMQEVVVHKKKRPVLRDYWQKHAGMAMLCETIEECWDHDAEARLSAGCVGERITQMQRLTNIITTEDIVTVVTMVTNVDFPPKESSL,TGFBR,acidophilic,23.0,352,LQLLEVKARGRFGCVWKAQLLNEYVAVKIFPIQDKQSWQNEYEVYSLPGMKHENILQFIGAEKRGTSVDVDLWLITAFHEKGSLSDFLKANVVSWNELCHIAETMARGLAYLHEDIPGLKDGHKPAISHRDIKSKNVLLKNNLTACIADFGLALKFEAGKSAGDTHGQVGTRRYMAPEVLEGAINFQRDAFLRIDMYAMGLVLWELASRCTAADGPVDEYMLPFEEEIGQHPSLEDMQEVVVHKKKRPVLRDYWQKHAGMAMLCETIEECWDHDAEARLSAGCVGERI,3.0
2,ACVR2B,ACTR2B,Q13705,ACVR2B,TKL,STKR,STKR2,STKR2,1,0,human,Q13705,aa161-end,MTAPWVALALLWGSLCAGSGRGEAETRECIYYNANWELERTNQSGLERCEGEQDKRLHCYASWRNSSGTIELVKKGCWLDDFNCYDRQECVATEENPQVYFCCCEGNFCNERFTHLPEAGGPEVTYEPPPTAPTLLTVLAYSLLPIGGLSLIVLLAFWMYRHRKPPYGHVDIHEDPGPPPPSPLVGLKPLQLLEIKARGRFGCVWKAQLMNDFVAVKIFPLQDKQSWQSEREIFSTPGMKHENLLQFIAAEKRGSNLEVELWLITAFHDKGSLTDYLKGNIITWNELCHVAETMSRGLSYLHEDVPWCRGEGHKPSIAHRDFKSKNVLLKSDLTAVLADFGLAVRFEPGKPPGDTHGQVGTRRYMAPEVLEGAINFQRDAFLRIDMYAMGLVLWELVSRCKAADGPVDEYMLPFEEEIGQHPSLEELQEVVVHKKMRPTIKDHWLKHPGLAQLCVTIEECWDHDAEARLSAGCVEERVSLIRRSVNGTTSDCLVSLVTSVTNVDLPPKESSI,MTAPWVALALLWGSLCAGSGRGEAETRECIYYNANWELERTNQSGLERCEGEQDKRLHCYASWRNSSGTIELVKKGCWLDDFNCYDRQECVATEENPQVYFCCCEGNFCNERFTHLPEAGGPEVTYEPPPTAPTLLTVLAYSLLPIGGLSLIVLLAFWMYRHRKPPYGHVDIHEDPGPPPPSPLVGLKPLQLLEIKARGRFGCVWKAQLMNDFVAVKIFPLQDKQSWQSEREIFSTPGMKHENLLQFIAAEKRGSNLEVELWLITAFHDKGSLTDYLKGNIITWNELCHVAETMSRGLSYLHEDVPWCRGEGHKPSIAHRDFKSKNVLLKSDLTAVLADFGLAVRFEPGKPPGDTHGQVGTRRYMAPEVLEGAINFQRDAFLRIDMYAMGLVLWELVSRCKAADGPVDEYMLPFEEEIGQHPSLEELQEVVVHKKMRPTIKDHWLKHPGLAQLCVTIEECWDHDAEARLSAGCVEERVSLIRRSVNGTTSDCLVSLVTSVTNVDLPPKESSI,RHRKPPYGHVDIHEDPGPPPPSPLVGLKPLQLLEIKARGRFGCVWKAQLMNDFVAVKIFPLQDKQSWQSEREIFSTPGMKHENLLQFIAAEKRGSNLEVELWLITAFHDKGSLTDYLKGNIITWNELCHVAETMSRGLSYLHEDVPWCRGEGHKPSIAHRDFKSKNVLLKSDLTAVLADFGLAVRFEPGKPPGDTHGQVGTRRYMAPEVLEGAINFQRDAFLRIDMYAMGLVLWELVSRCKAADGPVDEYMLPFEEEIGQHPSLEELQEVVVHKKMRPTIKDHWLKHPGLAQLCVTIEECWDHDAEARLSAGCVEERVSLIRRSVNGTTSDCLVSLVTSVTNVDLPPKESSI,TGFBR,acidophilic,23.0,352,LQLLEIKARGRFGCVWKAQLMNDFVAVKIFPLQDKQSWQSEREIFSTPGMKHENLLQFIAAEKRGSNLEVELWLITAFHDKGSLTDYLKGNIITWNELCHVAETMSRGLSYLHEDVPWCRGEGHKPSIAHRDFKSKNVLLKSDLTAVLADFGLAVRFEPGKPPGDTHGQVGTRRYMAPEVLEGAINFQRDAFLRIDMYAMGLVLWELVSRCKAADGPVDEYMLPFEEEIGQHPSLEELQEVVVHKKMRPTIKDHWLKHPGLAQLCVTIEECWDHDAEARLSAGCVEERV,3.0


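Make sure that `info` and the training `df` share the same index and row order, because the splits computed from `info` are used to select rows of `df`.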

In [None]:
splits = get_splits(info, stratified = 'hdbscan') # use info to split the data

StratifiedKFold(n_splits=5, random_state=123, shuffle=True)
# kinase hdbscan in train set: 8
# kinase hdbscan in test set: 8
---------------------------
# kinase in train set: 243
---------------------------
# kinase in test set: 60
---------------------------
test set: ['ACVR2A' 'ACVR2B' 'AKT1' 'AMPKA2' 'ANKRD3' 'AURA' 'BMPR1B' 'BMPR2' 'BRSK2' 'CAMK1D' 'CAMK2D' 'CDK17' 'CDK7' 'CHAK2' 'CK1A2' 'CLK3' 'DCAMKL1' 'DYRK3' 'ERK2' 'ERK7' 'GRK3' 'GRK6' 'HASPIN' 'HIPK3'
 'IRAK1' 'IRAK4' 'IRE1' 'JNK1' 'LOK' 'LRRK2' 'MEK1' 'MEKK1' 'MEKK2' 'MNK2' 'MOK' 'MRCKA' 'MSK1' 'MST4' 'MTOR' 'NIK' 'NUAK1' 'NUAK2' 'P38B' 'PAK3' 'PKACG' 'PKCE' 'PKCG' 'PRKD1' 'PRKD2' 'PRPK' 'RIPK1'
 'RSK2' 'SIK' 'SRPK1' 'TAO2' 'TGFBR1' 'TLK2' 'ULK1' 'WNK1' 'YSK4']


In [None]:
len(splits)

5

In [None]:
#| export
def split_data(df, # dataframe of values
               feat_col, # feature columns
               target_col, # target columns
               split # one of the split in splits
              ):
    "Slice `df` into train/test features and targets using one `(train, test)` entry of `splits`"
    
    # `split[0]` / `split[1]` are looked up with `.loc`, i.e. they are matched against index labels of `df`
    train_rows = df.loc[split[0]]
    X_train = train_rows[feat_col]
    y_train = train_rows[target_col]
    
    test_rows = df.loc[split[1]]
    X_test = test_rows[feat_col]
    y_test = test_rows[target_col]
    
    return X_train, y_train, X_test, y_test

In [None]:
# Column layout of `df`, as reflected by the shapes printed below: columns 1..198 are the
# 198 target columns, and every column from position 199 onward (1024 of them) is a feature.
# NOTE(review): column 0 is skipped by both slices — presumably an identifier column; confirm.
feat_col = df.columns[199:]
target_col = df.columns[1:199]

In [None]:
# feat_col = ['position'] + df.columns.tolist()[5:]
# target_col = ['target']

In [None]:
# split data
X_train, y_train, X_test, y_test = split_data(df, feat_col, target_col, splits[0])

In [None]:
print(f'X_train shape is : {X_train.shape}')
print(f'y_train shape is : {y_train.shape}')
print(f'X_test shape is : {X_test.shape}')
print(f'y_test shape is : {y_test.shape}')

X_train shape is : (242, 1024)
y_train shape is : (242, 198)
X_test shape is : (61, 1024)
y_test shape is : (61, 198)


## ML Trainer

In [None]:
#| export
def train_ml(df, # dataframe of values
             feat_col, # feature columns
             target_col, # target columns
             split, # one split in splits
             model,  # a sklearn models
             save = None, # file (.joblib) to save, e.g. 'model.joblib'
             params=None, # optional dict of extra keyword arguments forwarded to `model.fit`
            ):
    
    "Fit `model` on the train part of one split, optionally save it, and return `(y_test, y_pred)` for the test part"
    
    # `None` instead of a shared mutable `{}` default
    fit_params = {} if params is None else params
    
    # split data
    X_train, y_train, X_test, y_test = split_data(df, feat_col, target_col, split)
    
    # Fit the model (the passed-in `model` object itself is fitted, not a copy)
    model.fit(X_train, y_train, **fit_params) # better convert y_train to numpy array and flatten
    print(model)
    
    if save is not None:
        # Create the destination folder first so that saving to e.g. 'models/LR.joblib'
        # does not fail on a fresh checkout; anything that is not a path is passed to `dump` untouched
        if isinstance(save, (str, Path)):
            Path(save).parent.mkdir(parents=True, exist_ok=True)
        dump(model, save)
    
    # Predict
    y_pred = model.predict(X_test) # X_test is dataframe, y_pred is numpy array
    
    # Re-attach the labels of the ground truth so predictions and targets can be compared directly
    if isinstance(y_test, pd.Series):
        # `target_col` was a single column name: there is no `.columns` to copy, return a Series as well
        y_pred = pd.Series(np.ravel(y_pred), index=y_test.index, name=y_test.name)
    else:
        y_pred = pd.DataFrame(y_pred, index=y_test.index, columns=y_test.columns)
    
    return y_test, y_pred # same type, same index (and columns)

In [None]:
y_test, y_pred = train_ml(df, feat_col, target_col, splits[0],LinearRegression())

LinearRegression()


In [None]:
y_test.head(2) # ground truth

Unnamed: 0,-1A,-1C,-1D,-1E,-1F,-1G,-1H,-1I,-1K,-1L,-1M,-1N,-1P,-1Q,-1R,-1S,-1T,-1V,-1W,-1Y,-1t,-1y,-2A,-2C,-2D,-2E,-2F,-2G,-2H,-2I,-2K,-2L,-2M,-2N,-2P,-2Q,-2R,-2S,-2T,-2V,-2W,-2Y,-2t,-2y,-3A,-3C,-3D,-3E,-3F,-3G,-3H,-3I,-3K,-3L,-3M,-3N,-3P,-3Q,-3R,-3S,-3T,-3V,-3W,-3Y,-3t,-3y,-4A,-4C,-4D,-4E,-4F,-4G,-4H,-4I,-4K,-4L,-4M,-4N,-4P,-4Q,-4R,-4S,-4T,-4V,-4W,-4Y,-4t,-4y,-5A,-5C,-5D,-5E,-5F,-5G,-5H,-5I,-5K,-5L,-5M,-5N,-5P,-5Q,-5R,-5S,-5T,-5V,-5W,-5Y,-5t,-5y,1A,1C,1D,1E,1F,1G,1H,1I,1K,1L,1M,1N,1P,1Q,1R,1S,1T,1V,1W,1Y,1t,1y,2A,2C,2D,2E,2F,2G,2H,2I,2K,2L,2M,2N,2P,2Q,2R,2S,2T,2V,2W,2Y,2t,2y,3A,3C,3D,3E,3F,3G,3H,3I,3K,3L,3M,3N,3P,3Q,3R,3S,3T,3V,3W,3Y,3t,3y,4A,4C,4D,4E,4F,4G,4H,4I,4K,4L,4M,4N,4P,4Q,4R,4S,4T,4V,4W,4Y,4t,4y
4,1.252723,1.24883,0.873898,0.601646,1.047566,1.523919,1.680993,0.566336,1.786166,1.017873,1.297056,2.205684,1.478037,1.473652,2.12154,1.410493,1.146518,0.689077,0.971239,1.81427,0.580574,1.093674,1.357054,2.036879,0.416371,0.417932,0.612234,0.844795,1.129163,0.600633,3.882917,1.076347,0.670114,0.750983,0.528956,0.902214,8.277744,3.094029,2.293587,0.667872,0.600235,0.605753,0.662378,0.583033,1.033506,1.202165,0.421425,0.455981,0.568662,0.693936,1.079918,0.540987,3.777316,0.573559,0.575909,0.601408,0.672721,0.949034,19.928012,1.107675,1.014496,0.544039,0.580602,0.614292,0.684288,0.632425,1.488888,1.494032,0.740843,0.838902,0.907602,1.538154,1.33393,0.857336,2.570105,1.034472,1.151573,1.185649,1.260408,1.279404,3.04361,1.278347,1.277058,0.909421,0.87461,1.041559,0.975093,0.831319,1.317615,1.4301,0.62584,0.646126,1.122382,1.264056,1.248934,0.857198,2.007806,1.010056,1.050676,0.973416,1.234496,1.013656,3.025823,1.135926,1.092143,0.886697,1.093696,1.117178,0.935301,0.975119,0.840923,1.658432,0.518939,0.47776,1.896991,0.910117,1.009849,1.216862,1.086259,1.179772,1.437809,1.169001,0.769496,1.265725,1.198927,1.295663,1.273282,1.076143,1.035844,1.034681,0.823531,0.942467,1.437826,2.002882,0.720787,0.643082,0.752372,1.521882,1.965847,0.821316,1.488958,0.975453,0.894401,1.372951,1.034323,0.984906,1.578773,4.191876,1.996739,1.028776,0.808967,0.982034,1.230224,1.26434,1.296462,1.892505,0.716739,0.711202,1.047595,1.544028,1.693936,0.751384,2.174711,1.045462,0.895297,1.119572,1.246545,1.174084,2.219066,1.569695,1.308552,0.904983,1.081797,1.080408,0.984368,1.225548,1.237134,1.606234,0.823663,0.796645,0.911979,1.644231,1.546106,0.920043,2.630636,0.986489,1.090272,1.419044,2.203153,1.485343,2.254299,1.756481,1.77573,0.857565,0.944533,1.02433,1.24821,0.950655
5,1.294072,1.153273,0.993227,0.519128,1.161256,1.899052,1.748892,0.364165,2.444212,1.107083,1.654374,2.678062,1.689569,1.531418,2.411177,1.527736,1.266838,0.596234,0.960301,2.166737,0.413546,1.167212,1.506838,1.504885,0.246259,0.234297,0.527537,0.815138,1.923967,0.451609,4.974254,1.273318,0.64494,0.615123,0.466501,0.706624,12.1467,3.648219,2.983792,0.534407,0.417579,0.537104,0.484827,0.433791,1.074513,0.920832,0.262625,0.278554,0.338557,0.665245,1.131245,0.263748,3.034166,0.332554,0.361946,0.486353,0.820732,1.01315,23.083859,1.046747,0.807574,0.296889,0.441138,0.448077,0.515564,0.449338,1.515622,1.229415,0.667259,0.758091,0.829927,1.412425,1.34551,0.699594,2.970424,0.876682,1.131393,1.145324,1.443855,1.386143,4.1828,1.399942,1.155929,0.868439,0.867052,0.987613,0.831115,0.729234,1.586096,1.221776,0.58173,0.547084,0.841509,1.329727,1.479166,0.722431,2.699287,0.697416,0.812174,0.919091,1.500741,0.99819,3.716144,1.941354,1.325112,0.867442,1.061004,0.942139,0.683293,0.773865,0.787243,1.547058,0.433654,0.371891,2.560955,0.883633,1.096996,1.467718,1.097079,1.294964,1.80792,1.147164,0.464697,1.021469,1.243349,1.603971,1.381737,1.397411,1.127963,1.187843,0.651683,0.733107,1.609683,2.148273,0.795194,0.547055,0.846036,1.829739,2.9227,1.026199,1.853814,1.401466,1.098968,1.626895,1.058891,1.016294,2.588813,4.498148,2.240851,1.454272,1.30234,1.153029,0.656788,0.931439,1.244741,1.687986,0.711769,0.666914,1.18251,1.69192,2.086424,1.346346,2.725641,1.251046,1.03046,1.30798,1.208911,1.408288,2.672972,1.787365,1.267646,1.019906,1.60012,1.248233,0.978806,1.06156,1.196977,1.307189,0.928225,0.786032,0.957953,1.701085,1.756645,1.095542,2.856858,1.080406,1.135535,1.623719,4.243776,1.553547,2.43524,2.120943,1.740623,1.045402,1.443644,1.105475,1.01347,0.819062


In [None]:
y_pred.head(2) # predicted values

Unnamed: 0,-1A,-1C,-1D,-1E,-1F,-1G,-1H,-1I,-1K,-1L,-1M,-1N,-1P,-1Q,-1R,-1S,-1T,-1V,-1W,-1Y,-1t,-1y,-2A,-2C,-2D,-2E,-2F,-2G,-2H,-2I,-2K,-2L,-2M,-2N,-2P,-2Q,-2R,-2S,-2T,-2V,-2W,-2Y,-2t,-2y,-3A,-3C,-3D,-3E,-3F,-3G,-3H,-3I,-3K,-3L,-3M,-3N,-3P,-3Q,-3R,-3S,-3T,-3V,-3W,-3Y,-3t,-3y,-4A,-4C,-4D,-4E,-4F,-4G,-4H,-4I,-4K,-4L,-4M,-4N,-4P,-4Q,-4R,-4S,-4T,-4V,-4W,-4Y,-4t,-4y,-5A,-5C,-5D,-5E,-5F,-5G,-5H,-5I,-5K,-5L,-5M,-5N,-5P,-5Q,-5R,-5S,-5T,-5V,-5W,-5Y,-5t,-5y,1A,1C,1D,1E,1F,1G,1H,1I,1K,1L,1M,1N,1P,1Q,1R,1S,1T,1V,1W,1Y,1t,1y,2A,2C,2D,2E,2F,2G,2H,2I,2K,2L,2M,2N,2P,2Q,2R,2S,2T,2V,2W,2Y,2t,2y,3A,3C,3D,3E,3F,3G,3H,3I,3K,3L,3M,3N,3P,3Q,3R,3S,3T,3V,3W,3Y,3t,3y,4A,4C,4D,4E,4F,4G,4H,4I,4K,4L,4M,4N,4P,4Q,4R,4S,4T,4V,4W,4Y,4t,4y
4,0.880791,0.622719,0.96055,0.700666,1.153082,1.564817,1.056182,0.856165,2.141996,1.751584,1.589074,1.355316,1.320749,1.206049,2.315694,1.194068,0.868294,0.926967,0.850907,1.126364,0.721614,1.673116,1.140022,0.522095,0.594378,1.666814,0.854763,0.661071,1.290889,0.658254,3.097906,1.223099,0.988903,0.856534,1.00486,1.477508,7.259325,1.932906,1.623866,0.65779,0.624887,0.641342,1.948765,5.038746,1.082604,0.659705,0.815372,0.980333,0.727818,1.084473,1.400426,0.67864,3.807357,0.815253,0.60041,0.680822,0.93572,1.061968,10.941114,1.491273,1.030589,0.72728,0.671362,0.66128,1.525312,2.452796,1.233271,0.813517,0.619934,0.992625,1.043517,1.19998,1.401803,1.019787,2.054327,1.10908,1.257301,0.929303,1.121664,1.155277,2.710951,1.388477,1.488443,1.17209,1.118719,1.063543,1.648121,1.598501,1.122952,0.797036,0.658023,0.663709,1.361683,0.86273,1.304848,1.229473,1.567571,1.561108,1.063051,0.857515,1.20066,1.005065,2.137386,1.623627,1.831581,1.313838,1.04585,1.246158,1.298929,1.71521,0.662742,0.340296,0.043465,0.576465,3.101009,0.906655,1.285566,1.634424,1.160999,1.248203,1.464651,1.062558,3.968082,0.985305,1.307601,1.682637,1.185116,1.709194,1.507619,1.366714,1.796392,2.048392,1.593327,1.102411,1.378234,0.357327,1.073161,1.492891,1.59301,1.244568,1.641541,1.238622,0.822334,1.188205,1.378617,1.181676,3.078151,3.686029,2.05898,1.334777,0.382799,0.609451,1.281502,0.541642,1.285201,0.893367,0.892318,0.337786,1.384147,1.674513,1.384488,1.017258,2.380152,1.188938,1.155136,1.226855,1.802301,1.279146,2.279682,2.776567,1.584899,1.132408,0.840852,1.149841,1.152554,1.255384,1.371209,0.916174,1.22815,0.930996,1.05192,1.599864,1.578821,0.795891,2.198844,0.963892,0.980861,1.498177,1.948126,1.704162,1.934293,2.542421,1.580968,1.088194,1.482994,0.998058,3.797018,0.912856
5,0.929538,0.858041,1.000899,0.58166,1.043249,1.839039,1.365918,0.516358,2.073915,1.331279,1.562119,1.919175,1.276413,1.375619,2.547575,1.392574,1.078016,0.694548,0.958035,1.597075,0.477856,1.152724,1.397301,1.213398,0.128213,0.475064,0.376965,0.784733,1.132425,0.578571,3.133225,1.418803,0.71336,0.804273,0.225411,0.904714,9.750872,2.875341,2.452484,0.699616,0.346472,0.221861,1.175775,1.158697,0.926652,0.906074,0.480399,0.860973,0.528511,0.510841,1.292207,0.415131,3.373493,0.613372,0.504264,0.531332,0.619454,0.815627,13.341736,1.605012,1.128888,0.420068,0.605516,0.696449,1.201428,1.959073,1.201881,0.943274,0.867961,1.07328,1.044476,1.18035,1.436667,1.032511,2.404923,1.193812,1.223354,1.071209,0.924396,1.152283,3.067859,1.826459,1.721899,1.133352,1.548273,1.055927,1.951309,1.364031,1.097885,1.067915,0.879728,0.911181,1.213839,1.103597,1.305292,1.129559,1.799078,1.28666,1.108686,0.910434,1.167149,0.978838,2.780817,2.203076,1.483477,0.97439,0.97109,1.190445,1.083757,1.079758,0.217424,1.118824,0.22835,0.404896,3.125341,1.976914,1.104548,1.997904,0.998664,1.271509,1.969985,0.949745,0.632168,1.209517,1.319231,1.426457,1.527416,1.807208,1.511742,1.181756,0.733854,0.737861,1.190224,1.484794,0.565462,0.458537,0.813866,1.346953,2.039813,1.35684,1.489837,1.368404,0.964276,1.092393,0.672661,0.949666,3.101026,3.495048,3.253386,1.550002,0.965943,1.062792,0.601654,1.853414,1.159375,1.496792,0.937857,0.695079,1.807525,1.598659,2.118884,1.331331,1.117564,1.532495,1.448284,1.568873,0.768246,1.363604,2.489908,1.990832,1.817188,1.264662,1.454764,1.528902,0.898745,1.650651,1.348573,1.228396,1.188106,0.882571,1.262785,1.710181,1.903268,1.15811,3.048677,1.523252,1.364202,1.751578,2.470292,1.912025,3.085753,2.146339,1.848763,1.172421,1.716329,1.097679,1.270089,0.930688


In [None]:
#| export
def predict_ml(df, # Dataframe that contains features
               feat_col, # feature columns
               model_pth # models.joblib
              ):
    "Load a model saved with joblib from `model_pth` and predict on `df[feat_col]`; rows of the result keep the index of `df`"
    
    test = df[feat_col]
    
    # NOTE: `load` unpickles the file, which can execute arbitrary code; only load model files you trust
    model = load(model_pth)
    
    pred = model.predict(test)
    
    # Keep the input row labels so each prediction can be traced back to its row in `df`
    pred_df = pd.DataFrame(pred, index=test.index)
    
    return pred_df

In [None]:
# pred = predict_ml(X_test, feat_col, 'model.joblib')

In [None]:
# pred.head()

In [None]:
#| export
def score_all(target, pred):
    
    "Print the overall MSE, Spearman and Pearson correlation between two dataframes; need to have same index and columns"
    
    # Mean squared error (not the root); nothing is returned, the metrics are only printed
    print(f'mse is {mean_squared_error(target, pred):.4f}')

    # Correlations are computed over all cells at once, so flatten both frames to 1-D
    flat_target = target.values.ravel()
    flat_pred = pred.values.ravel()

    rho = spearmanr(flat_target, flat_pred)[0]
    print(f"Spearman correlation coefficient: {rho:.4f}")

    r = pearsonr(flat_target, flat_pred)[0]
    print(f"Pearson correlation coefficient: {r:.4f} ")

### Cross-Validation

In [None]:
#| export
def train_cv(df, # dataframe of values
             feat_col, # feature columns
             target_col,  # target columns
             splits, # splits
             model, # sklearn model
             save_name = None, # model name to be saved, e.g., 'LR'
             params = None, # optional dict of extra keyword arguments forwarded to `model.fit`
            ):
    "Train and score `model` on every split and return the out-of-fold predictions sorted by index"
    
    # `None` instead of a shared mutable `{}` default; always hand a real dict to `train_ml`
    fit_params = {} if params is None else params
    
    OOF = []
    for fold, split in enumerate(splits):
        print(f'------ fold: {fold} --------')
        
        save = None
        if save_name is not None:
            # One file per fold: the fold number must be part of the name,
            # otherwise every fold overwrites the model saved by the previous one
            save = f'models/{save_name}_{fold}.joblib'
            Path(save).parent.mkdir(parents=True, exist_ok=True)
        
        # The same `model` object is passed to (and re-fitted by) every fold
        target, pred = train_ml(df, feat_col, target_col, split, model, save, fit_params)

        score_all(target,pred)
        OOF.append(pred)
        
    # Each fold predicts its own test rows; stacking them gives one prediction per row
    oof_df = pd.concat(OOF).sort_index()
    
    return oof_df

In [None]:
oof = train_cv(df,feat_col, target_col, splits, LinearRegression())

------ fold: 0 --------
LinearRegression()
mse is 1.1492
Spearman correlation coefficient: 0.4699
Pearson correlation coefficient: 0.5758 
------ fold: 1 --------
LinearRegression()
mse is 1.4311
Spearman correlation coefficient: 0.4784
Pearson correlation coefficient: 0.5117 
------ fold: 2 --------
LinearRegression()
mse is 0.9352
Spearman correlation coefficient: 0.5135
Pearson correlation coefficient: 0.6516 
------ fold: 3 --------
LinearRegression()
mse is 1.0260
Spearman correlation coefficient: 0.4896
Pearson correlation coefficient: 0.6412 
------ fold: 4 --------
LinearRegression()
mse is 0.9092
Spearman correlation coefficient: 0.4865
Pearson correlation coefficient: 0.6101 


In [None]:
oof.head(2)

Unnamed: 0,-1A,-1C,-1D,-1E,-1F,-1G,-1H,-1I,-1K,-1L,-1M,-1N,-1P,-1Q,-1R,-1S,-1T,-1V,-1W,-1Y,-1t,-1y,-2A,-2C,-2D,-2E,-2F,-2G,-2H,-2I,-2K,-2L,-2M,-2N,-2P,-2Q,-2R,-2S,-2T,-2V,-2W,-2Y,-2t,-2y,-3A,-3C,-3D,-3E,-3F,-3G,-3H,-3I,-3K,-3L,-3M,-3N,-3P,-3Q,-3R,-3S,-3T,-3V,-3W,-3Y,-3t,-3y,-4A,-4C,-4D,-4E,-4F,-4G,-4H,-4I,-4K,-4L,-4M,-4N,-4P,-4Q,-4R,-4S,-4T,-4V,-4W,-4Y,-4t,-4y,-5A,-5C,-5D,-5E,-5F,-5G,-5H,-5I,-5K,-5L,-5M,-5N,-5P,-5Q,-5R,-5S,-5T,-5V,-5W,-5Y,-5t,-5y,1A,1C,1D,1E,1F,1G,1H,1I,1K,1L,1M,1N,1P,1Q,1R,1S,1T,1V,1W,1Y,1t,1y,2A,2C,2D,2E,2F,2G,2H,2I,2K,2L,2M,2N,2P,2Q,2R,2S,2T,2V,2W,2Y,2t,2y,3A,3C,3D,3E,3F,3G,3H,3I,3K,3L,3M,3N,3P,3Q,3R,3S,3T,3V,3W,3Y,3t,3y,4A,4C,4D,4E,4F,4G,4H,4I,4K,4L,4M,4N,4P,4Q,4R,4S,4T,4V,4W,4Y,4t,4y
0,0.985162,0.888783,0.650083,0.470425,1.453643,1.411065,1.417298,0.920627,2.458418,1.262223,1.503112,1.263942,1.93209,1.094558,2.081244,0.887029,0.975983,0.804702,1.167784,1.839061,0.638824,0.700638,1.519925,1.726393,0.852782,0.801714,0.881199,1.15369,1.232933,0.7997,1.638345,1.084671,1.812468,1.129932,1.980593,2.331065,3.62846,2.17957,2.055677,0.804283,0.745818,0.849442,0.43838,0.356337,1.532111,1.332136,-0.003621,0.093122,1.056355,0.994271,0.765361,1.246958,1.485118,0.922334,1.50004,0.976653,2.062295,0.916558,2.301953,1.300145,1.380309,0.954043,0.830219,0.979465,-0.563371,0.223295,1.332589,1.256111,0.576807,0.750933,1.111451,1.440795,0.986112,1.213366,1.581206,1.266443,1.149016,0.770259,1.584879,1.085441,1.753005,1.132172,1.211919,1.258714,1.021737,1.229256,0.113481,0.321791,0.917841,1.537907,0.286272,0.358621,1.253529,1.125725,1.174129,2.255844,0.890841,1.484774,1.409381,0.707453,1.358767,0.934388,2.043394,0.870646,1.149838,1.348048,1.448632,2.203252,0.1951,0.809193,0.869675,1.023177,0.879266,1.140842,0.90017,12.898273,0.293536,0.637649,0.826959,0.889004,0.515031,0.717413,2.993179,0.504604,1.040342,1.100091,1.443135,0.65512,0.522426,0.383007,1.11758,0.152828,1.181576,1.318348,1.444544,1.271309,1.417365,1.583491,1.750619,0.762573,1.204545,1.001189,1.161987,1.892462,1.151114,1.348847,1.52031,0.783146,4.07017,0.697158,1.372698,1.836292,1.146508,0.218075,1.132494,1.749113,0.993369,1.179805,1.380494,1.840563,1.459534,0.878186,2.463665,1.031125,0.989562,1.503613,1.817318,1.366825,1.705683,0.983024,4.478827,1.090159,1.341009,1.405892,0.912008,0.514723,1.056797,1.635825,1.028232,1.506865,1.458827,1.367372,1.281341,0.972839,1.955609,0.860715,1.225984,1.580088,1.861708,1.434961,2.290511,1.074263,2.419911,0.785828,1.784527,1.556943,2.697593,0.759031
1,0.860345,1.056054,-0.021727,0.73315,1.060984,0.664783,1.325944,0.880354,1.344422,1.336841,1.505645,1.433817,1.58512,1.343445,0.987193,0.875764,1.29881,0.934836,0.738402,1.175065,0.547596,0.800354,0.191479,1.09815,2.503217,6.07791,0.292255,0.382276,0.740177,0.624708,1.862604,0.778881,0.866448,0.608232,1.03954,1.084842,2.223496,2.411831,1.508868,0.650832,-0.140831,-0.134454,1.640363,0.989736,0.835254,0.890961,1.304374,0.466219,0.66281,0.589999,0.87961,0.713971,1.750399,0.402877,0.596547,1.199719,1.222576,1.138276,8.787333,1.50566,1.283365,0.780118,0.627713,0.746536,2.194788,0.426225,1.028474,1.172266,1.228348,0.953209,0.832176,1.299877,0.848722,0.658576,1.356633,0.591675,0.88546,0.993212,1.686173,0.978212,1.916378,1.488129,1.2823,0.903033,1.146155,0.883491,0.569094,0.927771,0.816491,1.03075,0.77943,0.381243,0.993179,1.161713,1.900947,0.974397,1.467198,0.884084,0.885221,0.855353,1.224898,0.920743,2.338325,0.137174,0.95527,0.749451,1.62605,0.972339,1.071668,1.101694,0.610453,1.300101,0.447875,1.836936,0.78038,2.999725,0.988816,1.302223,0.400237,1.372531,0.995811,1.213511,3.300998,1.177625,0.566577,1.33968,1.65279,1.072424,0.664591,0.744962,1.973112,1.848135,1.252339,1.400457,1.808257,0.860343,0.74932,2.415468,1.691803,0.460353,0.818736,0.179111,0.397072,1.005807,1.805958,1.172563,0.290194,2.229342,-0.937808,0.780147,0.660023,1.123837,1.017262,2.576274,1.049215,1.089985,0.971367,0.330528,0.7796,2.011519,1.396939,0.713822,1.528657,0.395893,0.544543,1.122328,2.057665,1.147255,1.855701,2.503104,1.97706,0.891791,0.494675,1.073295,1.122163,1.097552,1.227413,1.348919,1.086443,1.12678,0.613686,1.494505,1.201192,1.024059,1.019874,0.067005,0.875971,1.298144,2.068491,1.156066,0.624889,1.640546,1.730281,0.806572,1.174241,1.209895,1.304122,1.074957


In [None]:
oof.shape

(303, 198)

### Score

In [None]:
y = df[target_col]

score_all(y, oof)

mse is 1.0909
Spearman correlation coefficient: 0.4875
Pearson correlation coefficient: 0.5931 


In [None]:
#| export
def score_each(target, 
               pred,
               absolute=False, # If absolute, then will get absolute value of spearman and pearson
              ):
    "Calculate spearman and pearson per row; the result keeps the row index and has columns `spearman` and `pearson`"
    
    # Row-wise correlation between matching rows of `target` and `pred` (aligned on index and columns)
    sp = target.corrwith(pred,axis=1,method='spearman')
    pear = target.corrwith(pred,axis=1,method='pearson')
    
    # Build the frame from the two Series directly so the row labels are preserved
    # (stacking the raw values would silently replace them with 0..n-1)
    scores = pd.DataFrame({'spearman': sp, 'pearson': pear})
    if absolute:
        scores = scores.abs()
        
    print(f'average spearman for each row is {scores.spearman.mean()}')
    print(f'average pearson for each row is {scores.pearson.mean()}')
    return scores

In [None]:
cor = score_each(y,oof)

average spearman for each row is 0.5016147601026477
average pearson for each row is 0.5844748750730399


In [None]:
cor.head()

Unnamed: 0,spearman,pearson
0,0.726783,0.912949
1,0.191249,0.335579
2,0.241742,0.358883
3,0.853006,0.964424
4,0.549874,0.824573


## XGB

In [None]:
#| export
def _xgb_importance(model, importance_type):
    "Plot the 10 highest `importance_type` scores of a trained booster and return the full table, sorted descending"
    scores = model.get_score(importance_type=importance_type)
    col = f'{importance_type}_importance'
    table = (pd.DataFrame({'feature': list(scores.keys()), col: list(scores.values())})
             .set_index('feature')
             .sort_values(by=col, ascending=False))
    table[:10].plot.barh()
    plt.show()
    plt.close()
    return table


def xgb_trainer(df,
                feature_col,
                target_col,
                test_index=None,
                xgb_params = { 
                            'max_depth':7, #from 4 to 7
                            'learning_rate':0.001, #from 0.001
                            'subsample':0.8,
                            'colsample_bytree':1, # from 0.2 to 1, because need to take all features
                            'eval_metric':'rmse',
                            'objective':'reg:squarederror',
                            'tree_method':'gpu_hist',
                            'predictor':'gpu_predictor',
                            'random_state':123
                        },
                model_file='xgb_model.bin',
                split_seed = 123, # seed of random split
               ):
    """Train an XGBoost model on `df`, save it to `model_file`, and report how it does on held-out rows.

    The held-out rows are `test_index` (index labels of `df`) when given; otherwise a random 20%
    split seeded with `split_seed` is used. Prints the correlations between target and prediction,
    shows a scatter plot and two feature-importance bar charts.

    Returns `(pred_df, gain, weight)`: `pred_df` has columns `target` and `pred` indexed like the
    held-out rows; `gain` and `weight` are feature-importance tables sorted in descending order.
    """
    # NOTE: the default `xgb_params` dict is a single object shared by all calls; this function only reads it
    
    X = df[feature_col]
    y = df[target_col]
    
    print(f'xgb params is: {xgb_params}')
    
    if test_index is None:
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=split_seed)
    else:
        # Everything that is not explicitly held out is used for training
        is_test = X.index.isin(test_index)
        X_train, y_train = X.loc[~is_test], y.loc[~is_test]
        X_test, y_test = X.loc[test_index], y.loc[test_index]

        
    print(X_train.shape,y_train.shape,X_test.shape, y_test.shape)
    print(y_test.index)
    #prepare matrix for xgb
    dtrain = xgb.DMatrix(X_train, y_train)
    dtest = xgb.DMatrix(X_test, y_test)
    
    model = xgb.train(xgb_params, 
            dtrain=dtrain,
            evals=[(dtrain,'train'),(dtest,'valid')],
            num_boost_round=9999,
            early_stopping_rounds=100,
            verbose_eval=100,)
    
    # Make the target directory if it does not exist; `parents=True` so nested
    # destinations such as 'models/xgb/fold0.bin' work as well
    Path(model_file).parent.mkdir(parents=True, exist_ok=True)
        
    model.save_model(model_file)
    print(f'Model saved to {model_file}')
    
    # Prepare the pred/target df
    pred = model.predict(dtest)
    
    # Both arrays are flattened and paired column-wise, one row per held-out sample
    out = np.vstack([np.ravel(y_test),np.ravel(pred)]).T
    pred_df = pd.DataFrame(out,index=y_test.index, columns = ['target','pred'] )
    
    spearman_corr, _ = spearmanr(pred_df.target, pred_df.pred)
    print(f'Spearman correlation: {spearman_corr:.2f}')
    pearson_corr, _ = pearsonr(pred_df.target, pred_df.pred)
    print(f'Pearson correlation: {pearson_corr:.2f}')

    fig, ax = plt.subplots()
    ax.scatter(pred_df.target, pred_df.pred)
    ax.set_xlabel('True values')
    ax.set_ylabel('Predicted values')
    ax.set_title('Scatter plot of true versus predicted values')
    plt.show()
    plt.close()
    
    gain = _xgb_importance(model, 'gain')
    weight = _xgb_importance(model, 'weight')
    
    return pred_df, gain, weight

In [None]:
# xgb_trainer(df,feat_col, target_col)

In [None]:
#| export
def xgb_predict(df, # a dataframe that contains ID and features for prediction
                feature_col, #feature column name
                ID_col = "ID", #ID column name
                model_file='xgb_model.bin'):
    "Load the XGBoost model stored in `model_file` and return a dataframe with the `ID_col` values and their `preds`"
    # Restore the trained booster from disk
    booster = xgb.Booster()
    booster.load_model(model_file)
    
    # Score the feature columns of `df`
    scores = booster.predict(xgb.DMatrix(df[feature_col]))
    
    # Pair every prediction with the ID of the row it was computed from
    return pd.DataFrame({ID_col: df[ID_col], 'preds': scores})

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()