# Train DL

> A collection of deep-learning tools built on fastai

## Setup

In [1]:
#| default_exp dl

In [2]:
#| hide
import sys
sys.path.append("/notebooks/katlas")
from nbdev.showdoc import *
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [3]:
#| export
from fastbook import *
import fastcore.all as fc,torch.nn.init as init
from fastai.callback.training import GradientClip
from torch.utils.data import WeightedRandomSampler

# katlas
from katlas.core import *
from katlas.feature import *
from katlas.train import *

# sklearn
from sklearn.model_selection import *
from sklearn.metrics import mean_squared_error
from scipy.stats import spearmanr,pearsonr

## Utils

In [4]:
#| export
def seed_everything(seed=123):
    "Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic, for reproducible runs."
    # hash randomisation for any interpreter started with this environment
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # give up cuDNN autotuning speed in exchange for run-to-run determinism
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [5]:
# seed random / numpy / torch RNGs with the default seed (123)
seed_everything()

## Load Data

In [6]:
# read training data
df = pd.read_parquet('train_data/combine_t5_kd.parquet').reset_index()

# read data contains info for split; keep non-pseudo kinases only
info_df = Data.get_kinase_info_full().query('pseudo!="1"')

# merge info with training data (inner join on 'kinase')
info = df[['kinase']].merge(info_df)

# the splits below are computed from `info` but used to index `df` in train_dl,
# so both frames must stay row-aligned: an inner merge silently drops kinases removed
# by the pseudo filter and duplicates kinases listed twice in info_df
assert len(info) == len(df), f'info has {len(info)} rows but df has {len(df)}'
assert (info['kinase'].to_numpy() == df['kinase'].to_numpy()).all(), 'info and df rows are misaligned'

# splits
splits = get_splits(info,stratified='group')
split0 = splits[0]


# column name of feature and target
feat_col = df.columns[df.columns.str.startswith('T5_')]
# every non-feature column except the first one
# (presumably the 'kinase' id restored by reset_index(); verify if the parquet layout changes)
target_col = df.columns[~df.columns.isin(feat_col)][1:]

StratifiedKFold(n_splits=5, random_state=123, shuffle=True)
# kinase group in train set: 9
# kinase group in test set: 9
---------------------------
# kinase in train set: 312
---------------------------
# kinase in test set: 78
---------------------------
test set: ['EPHA3' 'FES' 'FLT3' 'FYN' 'EPHB1' 'EPHB3' 'FER' 'EPHB4' 'FLT4' 'FGFR1' 'EPHA5' 'TEK' 'DDR2' 'ZAP70' 'LIMK1' 'ULK3' 'JAK1' 'WEE1' 'TESK1' 'MAP2K3' 'AMPKA2' 'ATM' 'CAMK1D' 'CAMK2D' 'CAMK4' 'CAMKK1'
 'CK1D' 'CK1E' 'DYRK2' 'DYRK4' 'HGK' 'IKKE' 'JNK2' 'JNK3' 'KHS1' 'MAPKAPK5' 'MEK2' 'MSK2' 'NDR1' 'NEK6' 'NEK9' 'NIM1' 'NLK' 'OSR1' 'P38A' 'P38B' 'P90RSK' 'PAK1' 'PERK' 'PKCH' 'PKCI' 'PKN1' 'ROCK2'
 'RSK2' 'SIK' 'STLK3' 'TAK1' 'TSSK1' 'ALPHAK3' 'BMPR2' 'CDK10' 'CDK13' 'CDK14' 'CDKL5' 'GCN2' 'GRK4' 'IRE1' 'KHS2' 'MASTL' 'MLK4' 'MNK1' 'MRCKA' 'PRPK' 'QSK' 'SMMLCK' 'SSTK' 'ULK2' 'VRK1']


## Dataset

In [7]:
#| export
class GeneralDataset:
    def __init__(self, 
                 df, # a dataframe of values
                 feat_col, # feature columns (list-like, or a single column name)
                 target_col=None # Will return test set for prediction if target col is None
                ):
        "A general dataset that can be applied to any dataframe"
        # no target columns -> inference mode: __getitem__ returns features only
        self.test = target_col is None

        self.X = df[feat_col].values
        self.y = None if self.test else df[target_col].values

        self.len = df.shape[0]

    def __len__(self):
        return self.len

    @staticmethod
    def _to_tensor(value):
        "Copy a row (array) or a single value (scalar) into a float32 tensor."
        # torch.Tensor(x) misreads scalars: an int n gives an *uninitialised* tensor of
        # size n and a float raises, which happens when a single column name is passed.
        return torch.tensor(np.asarray(value, dtype=np.float32))

    def __getitem__(self, index):
        X = self._to_tensor(self.X[index])
        if self.test:
            return X
        return X, self._to_tensor(self.y[index])

In [8]:
# dataset over the whole dataframe: T5_* columns as inputs, target_col as labels
ds = GeneralDataset(df,feat_col,target_col)

In [9]:
# one item per row of df
len(ds)

390

In [10]:
# plain shuffled loader; reassigned below to a loader using the weighted sampler
dl = DataLoader(ds, batch_size=64, shuffle=True)

In [11]:
#| export
def get_sampler(info,col):
    "Sampler for imbalanced data: each row is drawn with probability proportional to 1 / (size of its `col` group)."
    counts = info[col].value_counts()
    # look up, for every row, how many rows share its group label
    rows_per_group = counts.loc[info[col]].to_numpy()
    # inverse frequency, floored at 0.01 (so groups larger than 100 rows are not shrunk further)
    weights = torch.from_numpy(1. / rows_per_group).clamp(min=0.01)
    # one draw per row of `info`, with replacement
    return WeightedRandomSampler(weights, len(weights), replacement=True)

In [12]:
# inverse-frequency weights by kinase subfamily, one weight per row of `info`
sampler = get_sampler(info,'subfamily')

In [13]:
# dataloader configured with the subfamily-balanced sampler (replaces the shuffled `dl`)
dl = DataLoader(ds, batch_size=64, sampler=sampler)

In [14]:
# peek at one batch to check input/target shapes
xb,yb = next(iter(dl))

xb.shape,yb.shape

(torch.Size([64, 1024]), torch.Size([64, 210]))

## Models

### MLP

In [15]:
#| export
def MLP_1(num_features, 
          num_targets,
          hidden_units = [512, 218],
          dp = 0.2):
    "Fully connected net: one [Linear -> BatchNorm1d -> Dropout -> PReLU] stage per hidden size, then a Linear head."
    widths = [num_features, *hidden_units]
    stages = []
    for n_in, n_out in zip(widths[:-1], widths[1:]):
        stages += [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.Dropout(dp), nn.PReLU()]
    # plain linear projection onto the targets (no activation)
    stages.append(nn.Linear(hidden_units[-1], num_targets))
    return nn.Sequential(*stages)

In [16]:
# model input / output sizes from the selected columns
n_feature = len(feat_col)
n_target = len(target_col)

In [17]:
# MLP with the default hidden sizes
model = MLP_1(n_feature, n_target)

In [18]:
# forward pass on the sample batch (output shape sanity check)
model(xb)

tensor([[-0.1115, -0.3755, -0.3818,  ..., -0.1483, -0.0387, -0.1111],
        [ 0.8555,  0.9352, -0.9642,  ..., -0.4723,  0.7757, -0.0121],
        [ 0.3422,  0.3537, -0.1441,  ...,  0.5467, -0.4535,  0.2103],
        ...,
        [-0.4287,  0.6751,  0.1797,  ...,  0.0192,  0.0692, -0.0573],
        [-0.0206, -0.1953,  0.7445,  ..., -0.2206, -0.1188,  0.4579],
        [ 0.2342, -0.0243,  0.4630,  ...,  0.8393,  0.5747, -0.6881]], grad_fn=<AddmmBackward0>)

### CNN1D

***Version 1***

In [19]:
#| export
class CNN1D_1(nn.Module):
    "Small 1-D CNN: two conv + max-pool stages over the feature vector, then a two-layer MLP head."

    def __init__(self, 
                 num_features, # length of the input feature vector; sizes fc1
                 num_targets): # number of outputs
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=3, kernel_size=3, dilation=1, padding=1, stride=1)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(in_channels=3, out_channels=8, kernel_size=3, dilation=1, padding=1, stride=1)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        # the convs keep the length (k=3, padding=1) and each pool floors it by half,
        # so the flattened size is 8 channels * (num_features // 4); the old
        # int(8 * num_features / 4) over-counts whenever num_features % 4 != 0
        self.fc1 = nn.Linear(in_features=8 * (num_features // 4), out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=num_targets)

    def forward(self, x):
        x = x.unsqueeze(1) # (bs, num_features) -> (bs, 1, num_features) for Conv1d
        x = self.pool1(nn.functional.relu(self.conv1(x)))
        x = self.pool2(nn.functional.relu(self.conv2(x)))
        x = self.flatten(x)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [20]:
# 1-D CNN, version 1
model = CNN1D_1(n_feature, n_target)

In [21]:
# forward pass on the sample batch (output shape sanity check)
model(xb)

tensor([[ 0.0193,  0.0690,  0.0138,  ..., -0.0428, -0.0026,  0.0840],
        [ 0.0203,  0.0693,  0.0136,  ..., -0.0422, -0.0023,  0.0846],
        [ 0.0198,  0.0703,  0.0148,  ..., -0.0424, -0.0029,  0.0839],
        ...,
        [ 0.0197,  0.0694,  0.0147,  ..., -0.0429, -0.0019,  0.0841],
        [ 0.0193,  0.0687,  0.0146,  ..., -0.0429, -0.0017,  0.0843],
        [ 0.0191,  0.0692,  0.0148,  ..., -0.0425, -0.0028,  0.0834]], grad_fn=<AddmmBackward0>)

***Version 2***

In [22]:
#| export
def init_weights(m, leaky=0.):
    "Initialize the weights of Conv1d/2d/3d layers with Kaiming normal; any other module is left untouched."
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    if not isinstance(m, conv_types):
        return
    # `leaky` is forwarded as Kaiming's negative-slope parameter `a`
    init.kaiming_normal_(m.weight, a=leaky)

In [23]:
#| export
def lin_wn(ni,nf,dp=0.1,act=nn.SiLU):
    "BatchNorm1d -> Dropout -> weight-normalised Linear(ni, nf), followed by `act()` unless `act` is falsy."
    parts = [nn.BatchNorm1d(ni), nn.Dropout(dp), nn.utils.weight_norm(nn.Linear(ni, nf))]
    if act:
        parts.append(act())
    return nn.Sequential(*parts)

In [24]:
#| export
def conv_wn(ni, nf, ks=3, stride=1, padding=1, dp=0.1,act=nn.ReLU):
    "BatchNorm1d -> Dropout -> weight-normalised Conv1d(ni, nf, ks, stride, padding), followed by `act()` unless `act` is falsy."
    parts = [
        nn.BatchNorm1d(ni),
        nn.Dropout(dp),
        nn.utils.weight_norm(nn.Conv1d(ni, nf, ks, stride, padding)),
    ]
    if act:
        parts.append(act())
    return nn.Sequential(*parts)

In [25]:
#| export
class CNN1D_2(nn.Module):
    "Dense expansion -> reshape to a (channels, length) map -> 1-D conv stack with a multiplicative skip -> linear head."

    def __init__(self, ni, nf, amp_scale = 16):
        super().__init__()

        in_ch, mid_ch, out_ch = 256, 512, 512
        expanded = in_ch * amp_scale                   # dense width (4096 with the default amp_scale)
        pooled_len = expanded // (in_ch * 2)           # length after the adaptive average pool
        head_in = (expanded // (in_ch * 4)) * out_ch   # flattened size fed to the head's linear layer

        # ni -> expanded features
        self.lin = lin_wn(ni, expanded)

        # (bs, expanded) -> (bs, in_ch, amp_scale)
        self.view = View(-1, in_ch, amp_scale)

        self.conv1 = nn.Sequential(
            conv_wn(in_ch, mid_ch, ks=5, stride=1, padding=2, dp=0.1),
            nn.AdaptiveAvgPool1d(output_size = pooled_len),
            conv_wn(mid_ch, mid_ch, ks=3, stride=1, padding=1, dp=0.1),
        )

        self.conv2 = nn.Sequential(
            conv_wn(mid_ch, mid_ch, ks=3, stride=1, padding=1, dp=0.3),
            conv_wn(mid_ch, out_ch, ks=5, stride=1, padding=2, dp=0.2),
        )

        self.head = nn.Sequential(
            nn.MaxPool1d(kernel_size=4, stride=2, padding=1),
            nn.Flatten(),
            lin_wn(head_in, nf, act=None),
        )

    def forward(self, x):
        # dense expansion, then reshape into channels x length for Conv1d
        gate = self.conv1(self.view(self.lin(x)))
        # conv2's output is multiplied element-wise with its own input (multiplicative skip)
        return self.head(self.conv2(gate) * gate)

In [26]:
# 1-D CNN, version 2, with Kaiming-normal init applied to its conv layers
model = CNN1D_2(n_feature,n_target).apply(init_weights)

In [27]:
# forward pass on the sample batch (output shape sanity check)
model(xb)

tensor([[-0.5740,  0.0151, -0.0819,  ...,  0.2636,  0.3405, -0.1404],
        [-0.6800,  0.5530, -0.0958,  ..., -0.3752, -0.6124,  0.7171],
        [ 0.4427, -0.3204, -0.3243,  ..., -0.2290,  0.1070,  0.1504],
        ...,
        [-0.3660, -0.2667, -0.6036,  ..., -0.3130,  0.5462, -0.0055],
        [ 0.4511,  0.6824,  0.8659,  ..., -0.0171,  0.2362, -0.3475],
        [-0.0746, -0.1699,  0.6895,  ...,  1.1522, -0.3472,  0.6422]], grad_fn=<AddmmBackward0>)

## DL Trainer

In [29]:
#| export
def train_dl(df, 
             feat_col, 
             target_col,
             split, # tuple of numpy array for split index (train, valid)
             model_func, # zero-argument function returning a fresh pytorch model
             n_epoch = 4, # number of epochs
             bs = 32, # batch size
             lr = 1e-2, # ignored (replaced by the lr_find suggestion) when lr_find is True
             loss = mse, # loss function
             save = None, # models/{save}.pth
             sampler = None, # optional sampler for the *training* loader only
             lr_find=False, # if true, will use lr from lr_find
            ):
    """A DL trainer: fit `model_func()` on the `split[0]` rows of `df` and evaluate on `split[1]`.

    `sampler`, when given, drives the training loader only, so it must produce indices into
    the training subset (i.e. be built from `len(split[0])` rows). The validation loader is
    always sequential.

    Returns `(target, pred)`: DataFrames indexed like `df.loc[split[1]]` with columns `target_col`.
    """
    # NOTE(review): `.loc` treats `split` as index labels; these equal row positions only
    # when df has a default RangeIndex (as after reset_index()).
    train = df.loc[split[0]]
    valid = df.loc[split[1]]

    train_ds = GeneralDataset(train, feat_col, target_col)
    valid_ds = GeneralDataset(valid, feat_col, target_col)

    n_workers = fc.defaults.cpus

    if sampler is not None:
        train_dl = DataLoader(train_ds, batch_size=bs, sampler=sampler, num_workers=n_workers)
        # Validation must see every row exactly once and in order: get_preds() below is
        # matched to valid.index, and a training-sized weighted sampler would resample
        # with replacement and index past len(valid_ds).
        valid_dl = DataLoader(valid_ds, batch_size=bs, shuffle=False, num_workers=n_workers)
        dls = DataLoaders(train_dl, valid_dl)
    else:
        dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=bs, num_workers=n_workers)

    model = model_func()

    learn = Learner(dls.cuda(), model.cuda(), loss,
                    metrics=[PearsonCorrCoef(), SpearmanCorrCoef()],
                    cbs=[GradientClip(1.0)] # clip the gradient norm to 1.0 (guards against exploding gradients)
                   )

    if lr_find:
        # lr_find() returns a SuggestedLRs namedtuple, not a float; take its first
        # suggestion so fit_one_cycle receives a number
        suggestion = learn.lr_find()
        plt.show()
        plt.close()
        print(suggestion)
        lr = suggestion[0]

    print('lr in training is', lr)
    learn.fit_one_cycle(n_epoch, lr)

    if save is not None:
        learn.save(save)

    # predictions / targets for the (sequential) validation loader
    pred, target = learn.get_preds()

    pred = pd.DataFrame(pred.detach().cpu().numpy(), index=valid.index, columns=target_col)
    target = pd.DataFrame(target.detach().cpu().numpy(), index=valid.index, columns=target_col)

    return target, pred

In [28]:
def get_model():
    "Build a fresh, untrained CNN1D_2 sized for this notebook's features and targets."
    return CNN1D_2(ni=n_feature, nf=n_target)

In [30]:
# single-fold sanity run on split0 (1 epoch, lr 1e-2); final weights saved as models/test.pth
target, pred = train_dl(df, 
                        feat_col, 
                        target_col,
                        split0, 
                        get_model,
                        n_epoch=1,
                        lr = 1e-2,
                        save = 'test')

lr in training is 0.01


epoch,train_loss,valid_loss,pearsonr,spearmanr,time
0,2.194746,2.176958,-0.104578,-0.053141,00:03


In [31]:
# overall MSE and per-row Pearson correlation on the held-out fold
score_each(target,pred)

overall MSE: 2.1770
Average Pearson: 0.2149 


(2.176958,
 0.21488270397776432,
       Pearson
 3   -0.442855
 8   -0.490345
 10  -0.401885
 19  -0.428557
 24  -0.383956
 ..        ...
 359 -0.127000
 361  0.005761
 366  0.095977
 367  0.335805
 373 -0.230842
 
 [78 rows x 1 columns])

## DL CV

In [32]:
#| export
@fc.delegates(train_dl)
def train_dl_cv(df, 
                feat_col, 
                target_col, 
                splits, # list of (train_idx, valid_idx) tuples
                model_func, # zero-argument function returning a fresh model, e.g. lambda: MLP_1(n_feature, n_target)
                save:str=None, # if set, fold k's final weights are saved as models/{save}_fold{k}.pth
                **kwargs
                ):
    "Run `train_dl` on every split; return out-of-fold predictions and a per-fold metrics table."
    fold_preds, fold_scores = [], []

    for fold, split in enumerate(splits):
        print(f'------fold{fold}------')

        fname = None if save is None else f'{save}_fold{fold}'
        target, pred = train_dl(df, feat_col, target_col, split, model_func, save=fname, **kwargs)

        mse, pearson_avg, _ = score_each(target, pred)
        fold_scores.append({'fold': fold, 'mse': mse, 'pearson_avg': pearson_avg})
        fold_preds.append(pred)

    # stitch every fold's validation predictions into one frame, sorted by index
    oof = pd.concat(fold_preds).sort_index()
    return oof, pd.DataFrame(fold_scores)

In [33]:
def get_model():
    "Return a new CNN1D_2 instance (n_feature inputs, n_target outputs) for each CV fold."
    return CNN1D_2(ni=n_feature, nf=n_target)

In [34]:
# cross-validation over every split (1 epoch per fold, fixed lr 3e-3)
oof,metrics = train_dl_cv(df,feat_col,target_col,splits,get_model,n_epoch=1,lr=3e-3)

------fold0------
lr in training is 0.003


epoch,train_loss,valid_loss,pearsonr,spearmanr,time
0,1.165076,0.997911,0.091948,0.058285,00:01


overall MSE: 0.9979
Average Pearson: 0.1634 
------fold1------
lr in training is 0.003


epoch,train_loss,valid_loss,pearsonr,spearmanr,time
0,1.180757,0.992539,0.102852,0.084205,00:01


overall MSE: 0.9925
Average Pearson: 0.1617 
------fold2------
lr in training is 0.003


epoch,train_loss,valid_loss,pearsonr,spearmanr,time
0,1.159264,0.98717,0.119972,0.098912,00:01


overall MSE: 0.9872
Average Pearson: 0.2364 
------fold3------
lr in training is 0.003


epoch,train_loss,valid_loss,pearsonr,spearmanr,time
0,1.184666,1.001829,0.077155,0.047876,00:01


overall MSE: 1.0018
Average Pearson: 0.1415 
------fold4------
lr in training is 0.003


epoch,train_loss,valid_loss,pearsonr,spearmanr,time
0,1.178444,0.992547,0.109969,0.100576,00:01


overall MSE: 0.9925
Average Pearson: 0.2014 


In [35]:
# per-fold validation MSE and average Pearson
metrics

Unnamed: 0,fold,mse,pearson_avg
0,0,0.997911,0.163423
1,1,0.992539,0.161654
2,2,0.98717,0.236363
3,3,1.001829,0.141464
4,4,0.992547,0.201375


In [36]:
# mean over folds of the per-fold average Pearson
metrics.pearson_avg.mean()

0.18085578818910147

In [37]:
# score the stitched out-of-fold predictions against the targets of every row
target = df[target_col]
_,_,corr = score_each(target,oof)

overall MSE: 0.9944
Average Pearson: 0.1809 


In [38]:
# per-row Pearson correlation between OOF predictions and targets
corr

Unnamed: 0,Pearson
0,-0.183429
1,-0.178564
2,-0.225202
3,-0.117838
4,-0.153463
...,...
385,0.109792
386,0.269238
387,0.079601
388,0.063310


## DL Predict

In [39]:
#| export
def predict_dl(df, 
               feat_col, 
               target_col,
               model, # model architecture (an instance; its weights are replaced by the checkpoint)
               model_pth, # only name, not with .pth
              ):
    """Predict dataframe given a deep learning model.

    Loads the checkpoint `model_pth` (the name passed as `save=` to `train_dl`) into `model`
    and returns a DataFrame indexed like `df` with one column per entry of `target_col`.
    """
    test_dset = GeneralDataset(df,feat_col)
    # the loader is not asked to shuffle, so batch order follows df's row order
    test_dl = DataLoader(test_dset,bs=512)

    # Learner is only used to load the checkpoint: no dls and a placeholder loss
    learn = Learner(None, model.cuda(), loss_func=1)
    learn.load(model_pth)

    learn.model.eval()
    device = next(learn.model.parameters()).device

    preds = []
    with torch.no_grad():  # inference only: don't build autograd graphs
        for batch in test_dl:
            # test-mode GeneralDataset yields features only; tolerate a 1-tuple batch as well
            inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
            outputs = learn.model(inputs.to(device))
            preds.append(outputs.cpu().numpy())

    preds = np.concatenate(preds)
    return pd.DataFrame(preds,index=df.index,columns=target_col)

In [40]:
# held-out (validation) rows of the first split
test = df.loc[split0[1]]

In [41]:
# predict the held-out rows with the checkpoint saved as 'test' by the single-fold run;
# get_model() builds the same architecture that checkpoint was trained with, instead of
# depending on whichever `model` happened to be defined last
pred = predict_dl(test,
                  feat_col,
                  target_col, 
                  get_model(),'test')
pred

Unnamed: 0,-5P,-5G,-5A,-5C,-5S,-5T,-5V,-5I,-5L,-5M,-5F,-5Y,-5W,-5H,-5K,-5R,-5Q,-5N,-5D,-5E,-5s,-5t,-5y,-4P,-4G,-4A,-4C,-4S,-4T,-4V,-4I,-4L,-4M,-4F,-4Y,-4W,-4H,-4K,-4R,-4Q,-4N,-4D,-4E,-4s,-4t,-4y,-3P,-3G,-3A,-3C,-3S,-3T,-3V,-3I,-3L,-3M,-3F,-3Y,-3W,-3H,-3K,-3R,-3Q,-3N,-3D,-3E,-3s,-3t,-3y,-2P,-2G,-2A,-2C,-2S,-2T,-2V,-2I,-2L,-2M,-2F,-2Y,-2W,-2H,-2K,-2R,-2Q,-2N,-2D,-2E,-2s,-2t,-2y,-1P,-1G,-1A,-1C,-1S,-1T,-1V,-1I,-1L,-1M,-1F,-1Y,-1W,-1H,-1K,-1R,-1Q,-1N,-1D,-1E,-1s,-1t,-1y,1P,1G,1A,1C,1S,1T,1V,1I,1L,1M,1F,1Y,1W,1H,1K,1R,1Q,1N,1D,1E,1s,1t,1y,2P,2G,2A,2C,2S,2T,2V,2I,2L,2M,2F,2Y,2W,2H,2K,2R,2Q,2N,2D,2E,2s,2t,2y,3P,3G,3A,3C,3S,3T,3V,3I,3L,3M,3F,3Y,3W,3H,3K,3R,3Q,3N,3D,3E,3s,3t,3y,4P,4G,4A,4C,4S,4T,4V,4I,4L,4M,4F,4Y,4W,4H,4K,4R,4Q,4N,4D,4E,4s,4t,4y,0s,0t,0y
3,0.827471,0.100224,0.832046,-0.875782,-0.163213,0.154136,-0.493984,-0.652184,-1.139407,-0.545682,0.170386,0.504751,-0.254571,-0.666113,-1.793947,-0.720070,-0.281525,-1.130579,-0.628133,-0.831370,1.536831,0.433035,0.392849,0.522302,-1.906116,0.269045,0.260138,-0.449420,-0.219241,-0.081720,-0.265319,-1.150107,-0.426660,0.034230,1.186783,1.155887,-0.541298,-0.147764,-0.391457,-0.532191,0.390349,-0.117091,-0.544431,-0.817431,0.209603,0.319055,-0.668265,-0.089056,-0.267722,0.873506,0.363645,0.925665,0.198308,0.376128,-0.091183,0.800630,2.594943,0.942247,0.381557,0.378796,-2.007185,-3.102536,-0.451001,0.324614,-0.403457,0.617391,0.573955,1.461275,0.632500,-1.116702,0.067198,-1.525167,0.533372,0.449234,-0.630748,0.000188,-0.406005,-0.597513,-0.477576,0.444630,1.485165,0.240802,0.838691,-2.154814,-2.045017,-0.508767,-0.634132,1.832220,-0.143416,-0.000741,0.834225,0.820550,-1.006547,-0.146693,-1.026853,1.660720,1.229229,0.859188,-0.824103,0.100888,-0.246177,1.906062,0.474571,1.582593,0.505408,0.352167,-1.356897,1.309204,1.164532,1.135649,-0.438278,0.429841,-1.104952,-1.197532,-0.390957,-1.573889,0.285766,-1.436248,0.650060,-0.180125,1.094425,0.196023,0.670384,0.781686,2.145067,0.097861,1.975262,1.960139,0.132345,0.234397,-0.731525,-1.392358,-0.154904,-1.966068,-2.782531,-0.367288,0.420465,-1.334604,-0.415886,-0.817545,0.334413,-0.257776,0.588528,0.056089,0.686094,-0.851600,-1.388452,0.962025,0.140007,1.217812,0.419813,1.218859,-1.105019,0.978566,-1.334749,-1.337203,-1.361357,-1.715217,-1.594962,0.368009,-0.917286,-0.569305,-0.249074,-0.784541,-0.572350,-0.319030,-0.301318,-0.291523,0.361332,-0.995329,0.845143,0.657074,1.600054,0.465445,0.819773,-0.739234,0.897671,-0.272619,-1.261856,-0.711332,-0.608858,-0.219403,1.353145,-0.526272,0.668316,-0.127292,-0.938279,-0.849241,0.974044,0.501423,0.257699,-0.155385,-1.242730,-0.218370,0.742504,0.296676,0.603466,-0.135024,1.358856,0.841600,0.340465,-1.716797,-0.272171,-1.086941,-0.108759,-1.107707,-0.915058,0.144688,1.703715,-1.096712
8,0.917410,0.171234,0.785013,-0.774128,-0.305260,-0.036252,-0.523487,-0.520050,-1.101476,-0.484559,-0.136079,0.453532,-0.157610,-0.722576,-1.705487,-0.679175,-0.613452,-0.849034,-0.617619,-0.829236,1.325297,0.568228,0.258706,0.530302,-1.917059,0.337471,0.461158,-0.513729,-0.009432,0.032227,-0.443830,-1.060179,-0.573893,0.035590,1.081311,1.032675,-0.778144,-0.034111,-0.088034,-0.417073,0.476150,-0.208814,-0.773442,-0.821986,0.408884,0.496548,-0.434809,-0.302523,-0.496544,0.771924,0.573002,0.862004,0.326757,0.434013,-0.269816,0.741980,2.483163,0.947341,0.286297,0.282709,-1.931927,-2.781210,-0.400135,0.336440,-0.553943,0.472467,0.712215,1.524756,0.716660,-1.030302,-0.101940,-1.563474,0.541147,0.501969,-0.687778,0.004106,-0.320514,-0.775591,-0.591284,0.515990,1.292747,0.403782,0.881161,-2.134602,-1.845125,-0.745882,-0.667191,1.685047,-0.176739,0.191540,0.742452,0.930569,-0.830662,-0.155687,-0.839995,1.647398,1.435268,0.932009,-0.948017,0.152050,-0.512855,1.643587,0.623791,1.754626,0.770352,0.206880,-1.215789,1.259650,1.313971,1.006722,-0.772838,0.465858,-0.794970,-1.441722,-0.327376,-1.471885,0.179104,-1.116795,0.668015,-0.096735,1.281591,0.349625,0.418560,0.723943,2.113839,0.183780,1.841187,1.912386,0.146610,0.131635,-0.489430,-1.257894,-0.009754,-1.919960,-2.586429,-0.311727,0.440352,-1.219443,-0.233799,-0.718143,0.118814,-0.150705,0.468099,0.415827,0.750365,-0.762541,-1.391950,0.997062,0.132728,1.180399,0.698985,0.900694,-1.148729,0.792337,-1.190787,-1.193467,-1.286419,-1.784168,-1.594329,0.269658,-0.706102,-0.771091,-0.014769,-0.608701,-0.543522,-0.351823,-0.201778,-0.185016,0.034217,-1.021030,0.948197,0.687479,1.681320,0.351855,0.779538,-0.798065,0.688631,-0.278176,-1.124563,-0.408594,-0.707703,-0.373782,1.323823,-0.405751,0.642940,-0.336016,-0.823702,-0.762177,0.757126,0.727935,0.251524,-0.214944,-1.259664,-0.309738,0.627646,0.497230,0.705490,-0.288012,1.101066,0.818683,0.443231,-1.763009,-0.266467,-1.181371,-0.017256,-0.936991,-0.793185,0.333979,1.577998,-1.18900
8
10,0.837395,0.109697,0.832660,-0.870176,-0.180581,0.144930,-0.501664,-0.639778,-1.142484,-0.551874,0.151722,0.500778,-0.250504,-0.668351,-1.793540,-0.717512,-0.301737,-1.106298,-0.622549,-0.831257,1.534428,0.446405,0.384790,0.528305,-1.910708,0.277465,0.273321,-0.455295,-0.206712,-0.077194,-0.282748,-1.152125,-0.438941,0.030492,1.181790,1.150298,-0.555282,-0.143044,-0.377379,-0.521493,0.399298,-0.122062,-0.554159,-0.819373,0.228382,0.329104,-0.644800,-0.104745,-0.276984,0.866417,0.376644,0.924284,0.213041,0.387116,-0.109702,0.801984,2.594317,0.938948,0.380097,0.372417,-2.003545,-3.086225,-0.451124,0.331583,-0.413769,0.611423,0.585708,1.464353,0.644266,-1.113469,0.055849,-1.528910,0.539915,0.458593,-0.631647,-0.000287,-0.401935,-0.607183,-0.481346,0.456771,1.477684,0.245798,0.845344,-2.158220,-2.034769,-0.522975,-0.642835,1.828530,-0.150154,0.011628,0.830574,0.830671,-0.998990,-0.146710,-1.013373,1.662167,1.246780,0.865196,-0.829360,0.102111,-0.261659,1.887859,0.491953,1.590855,0.519799,0.346759,-1.349203,1.298726,1.178621,1.132242,-0.463632,0.429732,-1.092150,-1.221142,-0.389510,-1.571760,0.278079,-1.421545,0.654995,-0.173786,1.097642,0.208115,0.653605,0.791366,2.145888,0.107205,1.972795,1.967754,0.133712,0.220518,-0.719943,-1.387880,-0.145486,-1.966254,-2.776184,-0.357707,0.420155,-1.339678,-0.407719,-0.815095,0.320433,-0.256512,0.580443,0.077515,0.691271,-0.851165,-1.388742,0.972441,0.139406,1.215708,0.434404,1.205417,-1.109673,0.969308,-1.328043,-1.331450,-1.353500,-1.721882,-1.599746,0.369419,-0.908263,-0.583409,-0.224611,-0.774605,-0.570776,-0.326263,-0.293524,-0.288654,0.341957,-0.999525,0.852286,0.663934,1.609160,0.456745,0.825500,-0.742155,0.896928,-0.279878,-1.254960,-0.689967,-0.616655,-0.226224,1.354397,-0.521402,0.670991,-0.148360,-0.929999,-0.850310,0.961717,0.518277,0.263442,-0.164127,-1.248991,-0.224318,0.731391,0.316439,0.615672,-0.148183,1.353209,0.843198,0.354706,-1.723555,-0.269565,-1.092407,-0.105944,-1.089590,-0.905481,0.162645,1.700939,-1.10912
1
19,0.781596,0.054428,0.886624,-1.019367,-0.026675,0.245834,-0.415021,-0.735891,-1.145012,-0.566665,0.411863,0.574012,-0.374369,-0.631664,-1.854622,-0.716872,-0.096983,-1.327580,-0.620644,-0.825760,1.675245,0.339548,0.467944,0.451587,-1.878664,0.224201,0.061338,-0.406528,-0.328586,-0.100980,-0.146840,-1.148672,-0.323020,0.025484,1.216742,1.189036,-0.352998,-0.173629,-0.544733,-0.595421,0.373671,-0.088120,-0.404013,-0.829182,0.029847,0.155994,-0.801714,0.054961,-0.098439,0.903357,0.248087,1.015439,0.119524,0.294137,0.071563,0.855235,2.605291,0.927736,0.496522,0.432004,-2.082788,-3.447055,-0.448698,0.328277,-0.277545,0.708672,0.488267,1.354602,0.506899,-1.202510,0.211965,-1.520520,0.463711,0.379593,-0.555098,0.053422,-0.462469,-0.450279,-0.396622,0.425477,1.636369,0.182411,0.838295,-2.158264,-2.224430,-0.385073,-0.570999,1.920744,-0.132848,-0.143569,0.833639,0.680066,-1.108618,-0.076675,-1.155808,1.666834,1.063978,0.841704,-0.793171,0.124623,-0.043953,2.057241,0.364838,1.394512,0.317737,0.453909,-1.453573,1.355068,1.030030,1.227987,-0.123894,0.402724,-1.298145,-0.952391,-0.406544,-1.720648,0.356935,-1.694481,0.614875,-0.241664,0.999034,0.101253,0.854732,0.846086,2.136112,0.035509,2.088189,1.922130,0.126163,0.307088,-0.869824,-1.504479,-0.239283,-1.973629,-2.910555,-0.379102,0.384176,-1.440574,-0.547820,-0.893037,0.515938,-0.321208,0.708824,-0.219489,0.671979,-0.955390,-1.424811,0.890984,0.133275,1.251851,0.200973,1.460850,-1.072274,1.118181,-1.424450,-1.432409,-1.434649,-1.693852,-1.590635,0.398966,-1.043202,-0.442988,-0.422975,-0.857704,-0.582256,-0.268851,-0.305556,-0.369747,0.599190,-0.958157,0.804685,0.616220,1.512328,0.540506,0.841114,-0.683778,1.090572,-0.269047,-1.389140,-0.953336,-0.503750,-0.122819,1.367146,-0.584153,0.665373,0.062932,-0.983125,-0.920183,1.157806,0.317912,0.243558,-0.128526,-1.204561,-0.181557,0.803107,0.091521,0.471424,-0.031081,1.550624,0.846450,0.293305,-1.660746,-0.262077,-0.982250,-0.130795,-1.255570,-0.984373,-0.038575,1.802894,-0.973010
24,0.786290,0.079069,0.880414,-0.981494,-0.038584,0.230998,-0.448706,-0.692160,-1.134339,-0.558241,0.397392,0.563033,-0.363803,-0.645807,-1.846820,-0.700025,-0.118564,-1.274372,-0.601713,-0.807587,1.648476,0.357077,0.448028,0.461057,-1.872094,0.237823,0.096480,-0.420398,-0.307801,-0.088697,-0.160272,-1.150119,-0.345459,0.019907,1.206092,1.175314,-0.374911,-0.167590,-0.533220,-0.586528,0.391906,-0.093227,-0.418521,-0.830271,0.042791,0.180905,-0.773885,0.040752,-0.106419,0.883396,0.268997,1.007495,0.121850,0.308463,0.055195,0.858030,2.591177,0.925743,0.508263,0.433671,-2.070256,-3.375800,-0.450623,0.316248,-0.286478,0.691198,0.480683,1.351313,0.522532,-1.194276,0.200118,-1.516695,0.493254,0.390673,-0.541645,0.042630,-0.433976,-0.454379,-0.409099,0.430472,1.601456,0.181383,0.830983,-2.139535,-2.181069,-0.409251,-0.565177,1.900097,-0.133536,-0.130521,0.821082,0.711679,-1.089747,-0.074010,-1.138318,1.658434,1.066018,0.845442,-0.787894,0.131238,-0.058087,2.038095,0.378360,1.402364,0.341870,0.442856,-1.428763,1.334331,1.037531,1.213391,-0.144534,0.421373,-1.286405,-0.970518,-0.407965,-1.671879,0.345138,-1.652879,0.612410,-0.211857,0.989774,0.090873,0.821209,0.822212,2.119716,0.044015,2.056953,1.912315,0.124411,0.290797,-0.845661,-1.492286,-0.226419,-1.973430,-2.880046,-0.368402,0.403572,-1.430287,-0.548127,-0.880053,0.469158,-0.328019,0.686918,-0.187181,0.674285,-0.944263,-1.403961,0.894274,0.120857,1.223115,0.231179,1.422888,-1.071916,1.089855,-1.423921,-1.405556,-1.432558,-1.694204,-1.585991,0.395474,-1.001245,-0.472834,-0.398623,-0.836932,-0.580620,-0.272742,-0.293357,-0.356572,0.577947,-0.953720,0.788483,0.622082,1.512632,0.533714,0.830057,-0.697636,1.065160,-0.262681,-1.366525,-0.904157,-0.526866,-0.139730,1.364432,-0.561192,0.657009,0.038055,-0.958606,-0.907986,1.131933,0.334181,0.249553,-0.134938,-1.188955,-0.191562,0.787495,0.116197,0.479489,-0.046379,1.518119,0.835072,0.291257,-1.674454,-0.255510,-0.991072,-0.112540,-1.216258,-0.985859,-0.013125,1.791357,-0.978988
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
359,0.852422,0.135191,0.803335,-0.682628,-0.197535,0.111441,-0.534158,-0.505383,-1.162610,-0.538424,-0.021176,0.632151,-0.025112,-0.665469,-1.710826,-0.776242,-0.748986,-1.080172,-0.753498,-1.002758,1.160442,0.550065,0.357007,0.401259,-2.031294,0.323325,0.388282,-0.475727,-0.005309,0.068387,-0.466699,-1.210684,-0.459825,0.146122,1.294962,1.168700,-0.753919,-0.047600,-0.238652,-0.610689,0.442146,-0.204308,-0.976595,-0.820102,0.395491,0.482606,-0.600670,-0.329664,-0.522160,0.889844,0.524637,0.902219,0.242944,0.344369,-0.115987,0.732324,2.553711,1.148839,0.464048,0.280631,-1.980339,-2.848927,-0.505171,0.322223,-0.584744,0.362167,0.647299,1.535061,0.839859,-1.160352,0.008072,-1.660247,0.604313,0.489808,-0.741202,0.006355,-0.421018,-0.814922,-0.566360,0.559379,1.306701,0.443016,0.864872,-2.187451,-2.119461,-0.675246,-0.562467,1.644104,-0.176978,0.100741,0.715397,0.950394,-1.031454,-0.258865,-1.033383,1.611082,1.424277,0.961288,-0.975064,0.044345,-0.491647,1.827934,0.591021,1.873002,0.812739,0.245653,-1.284134,1.459863,1.388409,1.154546,-0.747415,0.378350,-0.958165,-1.388069,-0.316556,-1.421812,0.151718,-1.290187,0.739967,-0.089276,1.370041,0.249051,0.507383,0.729861,2.187623,0.181003,1.977932,2.022047,0.264086,0.244274,-0.468108,-1.399503,-0.107290,-2.080383,-2.838415,-0.406632,0.454743,-1.288757,-0.370415,-0.708696,0.102508,-0.075795,0.616586,0.344751,0.662345,-0.758140,-1.575279,1.061059,0.138530,1.270358,0.752367,1.059788,-1.100806,0.893171,-1.255185,-1.158705,-1.399722,-1.987958,-1.636354,0.228773,-0.869164,-0.753039,-0.128588,-0.742070,-0.508778,-0.252595,-0.170855,-0.253869,0.039418,-1.107684,0.952661,0.591774,1.760877,0.536997,0.963283,-0.797976,0.708162,-0.152582,-1.038628,-0.590417,-0.794161,-0.273169,1.187026,-0.527838,0.732147,-0.343438,-0.997734,-0.721266,0.939762,0.757312,0.115442,-0.188129,-1.381625,-0.186085,0.745146,0.534328,0.809492,-0.158795,1.185196,0.738872,0.452261,-1.792858,-0.426966,-1.256581,-0.159333,-1.088121,-0.948825,0.350924,1.641778,-1.24906
9
361,0.817391,0.066909,0.861728,-0.820828,-0.101933,0.223830,-0.550699,-0.614653,-1.208950,-0.595266,0.228076,0.639920,-0.157404,-0.638020,-1.808211,-0.756362,-0.398338,-1.223009,-0.697684,-0.968820,1.449097,0.473187,0.424028,0.475327,-2.021605,0.297017,0.266568,-0.453896,-0.175089,-0.027513,-0.297073,-1.238146,-0.377312,0.102429,1.313934,1.241425,-0.588323,-0.145568,-0.463696,-0.666938,0.386830,-0.134614,-0.758233,-0.821851,0.185983,0.341184,-0.751713,-0.169910,-0.330749,0.927613,0.389297,0.988620,0.155020,0.332752,0.021174,0.838551,2.690545,1.085305,0.538575,0.394114,-2.086341,-3.145844,-0.530758,0.315191,-0.448775,0.533272,0.559770,1.475773,0.770389,-1.240129,0.126447,-1.617360,0.597686,0.470475,-0.670058,0.023400,-0.437610,-0.676231,-0.452324,0.523027,1.514605,0.344220,0.889705,-2.213305,-2.251495,-0.506792,-0.616503,1.806869,-0.136031,-0.049555,0.759245,0.867065,-1.114460,-0.200896,-1.122525,1.682196,1.230175,0.931741,-0.889741,0.066583,-0.277517,2.005597,0.476001,1.742925,0.565623,0.373920,-1.420270,1.415827,1.261921,1.241692,-0.497678,0.411198,-1.188702,-1.208542,-0.370962,-1.572881,0.244546,-1.540803,0.739303,-0.129672,1.229159,0.129319,0.737835,0.824232,2.204368,0.126864,2.073917,2.056757,0.203360,0.289404,-0.674757,-1.511156,-0.245752,-2.143941,-3.013253,-0.439010,0.455978,-1.437115,-0.476252,-0.867653,0.287618,-0.223040,0.674525,0.074388,0.617374,-0.837220,-1.536570,1.013208,0.139984,1.303378,0.495366,1.281375,-1.111345,1.060646,-1.406862,-1.335690,-1.420924,-1.903942,-1.628278,0.356124,-1.046179,-0.562881,-0.304964,-0.853850,-0.522766,-0.279432,-0.254845,-0.328819,0.331825,-1.123803,0.887757,0.614448,1.672400,0.580790,0.964061,-0.730788,0.876058,-0.148131,-1.221625,-0.813070,-0.681926,-0.200397,1.266091,-0.545727,0.726326,-0.127788,-1.036352,-0.831623,1.073829,0.546031,0.226980,-0.141519,-1.352796,-0.140593,0.781673,0.364335,0.710168,-0.061953,1.407219,0.807473,0.365853,-1.783174,-0.357436,-1.184780,-0.175389,-1.161159,-0.956531,0.197732,1.771031,-1.17067
8
366,0.736367,-0.009696,0.969619,-1.127532,0.144750,0.292795,-0.344816,-0.781180,-1.181761,-0.555607,0.614502,0.655654,-0.490570,-0.556051,-1.909676,-0.697448,0.010961,-1.529093,-0.665679,-0.885523,1.743442,0.301108,0.559450,0.359698,-1.865560,0.225958,-0.073739,-0.328535,-0.424078,-0.050327,-0.072884,-1.164826,-0.198429,0.007245,1.300829,1.195708,-0.232114,-0.179775,-0.638567,-0.648339,0.314607,-0.098741,-0.331572,-0.839038,-0.144630,0.014305,-0.934783,0.157165,-0.019406,0.962281,0.193096,1.133443,0.046364,0.221449,0.229887,0.907029,2.635381,0.905171,0.632662,0.476889,-2.158721,-3.740595,-0.443866,0.284959,-0.223193,0.767338,0.484332,1.281591,0.391474,-1.312972,0.355046,-1.488309,0.388440,0.346864,-0.470980,0.133614,-0.532521,-0.332068,-0.307503,0.398594,1.762086,0.161883,0.861015,-2.121942,-2.416008,-0.306379,-0.482739,1.975375,-0.090602,-0.277195,0.757733,0.574751,-1.189831,-0.019551,-1.225224,1.658932,0.909126,0.879132,-0.793751,0.146416,0.077841,2.159410,0.267883,1.318121,0.158684,0.499024,-1.488347,1.447917,0.917704,1.320736,0.057925,0.380927,-1.443924,-0.727737,-0.378857,-1.840289,0.381628,-1.873738,0.618523,-0.257888,0.983127,0.014522,1.036318,0.927514,2.119387,0.038633,2.196174,1.898270,0.143137,0.371436,-0.925582,-1.606647,-0.376975,-2.012934,-3.053819,-0.435470,0.348005,-1.529234,-0.666514,-0.988957,0.652559,-0.380255,0.837427,-0.444440,0.626248,-1.035783,-1.540657,0.829535,0.134082,1.359996,0.086900,1.667849,-1.041736,1.303053,-1.536747,-1.482981,-1.517062,-1.724163,-1.569078,0.398250,-1.165087,-0.328690,-0.578908,-0.909962,-0.558582,-0.231092,-0.259670,-0.405262,0.770781,-0.956190,0.784130,0.556924,1.422143,0.587300,0.939328,-0.643971,1.251529,-0.208695,-1.483723,-1.188880,-0.426775,-0.085915,1.312669,-0.649168,0.642353,0.186576,-1.034911,-0.980427,1.331193,0.181913,0.198342,-0.116720,-1.205564,-0.123322,0.869537,-0.066354,0.388415,0.061960,1.728792,0.860009,0.278499,-1.639867,-0.325064,-0.934766,-0.202117,-1.406237,-1.035918,-0.152255,1.873838,-0.880204
367,0.724401,-0.019061,0.971441,-1.139351,0.158653,0.296373,-0.329004,-0.796939,-1.178109,-0.553247,0.629083,0.662251,-0.506471,-0.539309,-1.909250,-0.695755,0.027085,-1.557588,-0.666611,-0.887311,1.757227,0.295163,0.574550,0.343939,-1.853038,0.221966,-0.098035,-0.312960,-0.433927,-0.050586,-0.063850,-1.160515,-0.186830,0.006605,1.306884,1.193217,-0.210796,-0.178851,-0.643341,-0.651306,0.302899,-0.101557,-0.308720,-0.836970,-0.159593,-0.014414,-0.949915,0.176432,-0.004292,0.958615,0.181704,1.135270,0.040149,0.213926,0.250159,0.906866,2.618574,0.895872,0.638841,0.483742,-2.153453,-3.772979,-0.437052,0.280960,-0.207360,0.772313,0.490956,1.270018,0.367565,-1.315215,0.374075,-1.478493,0.366801,0.334773,-0.462913,0.137689,-0.541791,-0.312256,-0.291117,0.392223,1.779289,0.149574,0.856331,-2.111991,-2.431390,-0.285890,-0.475982,1.977784,-0.086828,-0.292871,0.759561,0.550275,-1.186828,-0.001344,-1.226874,1.650718,0.891176,0.877768,-0.788195,0.148672,0.095447,2.158775,0.255060,1.293651,0.131154,0.495155,-1.487620,1.452766,0.900406,1.327715,0.085691,0.372670,-1.454059,-0.701512,-0.379369,-1.845481,0.393916,-1.891263,0.620459,-0.267988,0.977500,0.010319,1.046921,0.935379,2.112456,0.044012,2.195603,1.885747,0.138098,0.379871,-0.937802,-1.609828,-0.394650,-1.996124,-3.050131,-0.443390,0.341929,-1.532011,-0.674782,-0.997301,0.681481,-0.375820,0.849379,-0.472301,0.617798,-1.040801,-1.539947,0.816455,0.132350,1.362182,0.063987,1.689147,-1.030874,1.322626,-1.540051,-1.482728,-1.514823,-1.709387,-1.563476,0.396952,-1.186564,-0.307285,-0.598137,-0.912492,-0.560189,-0.227781,-0.262265,-0.414221,0.786078,-0.944608,0.786475,0.549084,1.408655,0.586735,0.945035,-0.624524,1.278060,-0.202815,-1.502831,-1.215252,-0.413168,-0.073636,1.305231,-0.663560,0.639237,0.215288,-1.042258,-0.980603,1.350477,0.159159,0.197200,-0.115719,-1.201598,-0.107219,0.874845,-0.088925,0.373468,0.074698,1.746727,0.864847,0.270554,-1.626873,-0.326401,-0.907621,-0.213689,-1.427742,-1.043920,-0.174457,1.869718,-0.86112
0


In [42]:
# Evaluate the test-set predictions against the target columns of `test`.
# From the output below: score_each prints an overall MSE and an average Pearson,
# and returns a tuple (mse, avg_pearson, per-row Pearson DataFrame).
# NOTE(review): the per-row table has 78 rows, which presumably is one per test kinase
# (78 kinases in the test split above) -- confirm against score_each.
# NOTE(review): several rows show strongly negative Pearson (e.g. ~-0.49), so the
# average of ~0.21 hides poor per-kinase performance; worth inspecting.
score_each(test[target_col],pred)

overall MSE: 2.1770
Average Pearson: 0.2149 


(2.176958126296138,
 0.21488261204649625,
       Pearson
 3   -0.442855
 8   -0.490345
 10  -0.401885
 19  -0.428557
 24  -0.383956
 ..        ...
 359 -0.127000
 361  0.005761
 366  0.095977
 367  0.335805
 373 -0.230842
 
 [78 rows x 1 columns])

In [43]:
#| hide
# Write all `#| export` cells of this notebook out to the library module named by `#| default_exp`.
import nbdev
nbdev.nbdev_export()