In [1]:
# Automatically re-import edited project modules (e.g. models/, src/) before each cell runs,
# so changes to the helper code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import numpy as np
import pandas as pd
import torch
import argparse
from models.data_process import get_datatensor_partitions, prepare_nonproto_features, generate_partition_datatensor,get_data_ready
from models.dataset import ProtospacerDataset, ProtospacerExtendedDataset
from models.trainval_workflow import run_trainevaltest_workflow
from models.trainval_workflow import run_inference
from models.hyperparam import build_config_map
from src.utils import create_directory, one_hot_encode, get_device, ReaderWriter 
from src.utils import print_eval_results, plot_y_distrib_acrossfolds, compute_eval_results_df
import matplotlib.pyplot as plt

In [3]:
# Experiment options declared argparse-style. Inside a notebook sys.argv holds the
# kernel launcher's own flags, so parse_known_args() is used: flags not declared
# below are set aside (into `_`) instead of raising an error.
cmd_opt = argparse.ArgumentParser(description='Argparser for data')
for _flag, _type, _default, _help in [
    # flag            type  default       help text
    ('-model_name',  str,  None,         'name of the model'),
    ('-exp_name',    str,  None,         'name of the experiment'),
    ('-data_dir',    str,  './data/',    'directory of the data'),
    ('-target_dir',  str,  'processed',  'folder name to save the processed data'),
    ('-working_dir', str,  './',         'the main working directory'),
    ('-output_path', str,  None,         'path to save the trained model'),
    ('-random_seed', int,  42,           None),
    ('-epoch_num',   int,  200,          'number of training epochs'),
]:
    cmd_opt.add_argument(_flag, type=_type, default=_default, help=_help)
del _flag, _type, _default, _help  # keep the notebook namespace clean
args, _ = cmd_opt.parse_known_args()

### Predefined hyperparameters for a given model and experiment choice (data preparation is handled by `get_data_ready`)

In [4]:
# predefined hyperparameters depending on the chosen model and experiment
def get_hyperparam_config(args):
    """Return predefined ``(mconfig, options)`` hyperparameters for a model/experiment.

    Parameters
    ----------
    args : namespace
        Must provide ``model_name`` (one of ``'FFN'``, ``'CNN'``, ``'RNN'``,
        ``'Transformer'``) and ``exp_name``. When ``exp_name`` is
        ``'protospacer_extended'`` an MLP-embedder configuration is added and the
        input size becomes 30; any other value uses input size 20 and no embedder.

    Returns
    -------
    tuple
        ``(mconfig, options)`` as returned by ``build_config_map``; ``options`` is
        then updated with ``input_size``, ``loss_func``, ``model_name`` and
        ``perfmetric_name``.

    Raises
    ------
    ValueError
        If ``args.model_name`` is not one of the supported models.
    """
    optim_tup = None  # passed through to build_config_map unchanged

    if args.model_name == 'FFN':
        h = [60, 10]
        l2_reg = 0.1
        batch_size = 100
        num_epochs = 300
        model_config_tup = (h, l2_reg, batch_size, num_epochs)
        embedder_nonlin = torch.nn.ReLU
        loss_func_name = 'MSEloss'
        perfmetric_name = 'pearson'

    elif args.model_name == 'CNN':
        k = 2
        l2_reg = 0.5
        batch_size = 100
        num_epochs = 300
        model_config_tup = (k, l2_reg, batch_size, num_epochs)
        embedder_nonlin = torch.nn.ReLU
        loss_func_name = 'MSEloss'  # alternative tried: 'SmoothL1loss'
        perfmetric_name = 'spearman'

    elif args.model_name == 'RNN':
        embed_dim = 64
        hidden_dim = 64
        z_dim = 32
        num_hidden_layers = 2
        bidirection = True
        p_dropout = 0.1
        rnn_class = torch.nn.GRU
        nonlin_func = torch.nn.ReLU
        pooling_mode = 'none'
        l2_reg = 1e-5
        batch_size = 1500
        num_epochs = 500
        model_config_tup = (embed_dim, hidden_dim, z_dim, num_hidden_layers, bidirection,
                            p_dropout, rnn_class, nonlin_func, pooling_mode, l2_reg,
                            batch_size, num_epochs)
        embedder_nonlin = torch.nn.ReLU
        loss_func_name = 'SmoothL1loss'
        perfmetric_name = 'pearson'

    elif args.model_name == 'Transformer':
        embed_dim = 128
        num_attn_heads = 4
        num_trf_units = 1
        pdropout = 0.1
        activ_func = torch.nn.GELU
        multp_factor = 2
        multihead_type = 'Wide'
        pos_embed_concat_opt = 'stack'
        pooling_opt = 'none'
        weight_decay = 1e-8
        batch_size = 1000
        num_epochs = 1000
        model_config_tup = (embed_dim, num_attn_heads, num_trf_units,
                            pdropout, activ_func, multp_factor, multihead_type,
                            pos_embed_concat_opt, pooling_opt, weight_decay,
                            batch_size, num_epochs)
        embedder_nonlin = torch.nn.GELU
        loss_func_name = 'SmoothL1loss'
        perfmetric_name = 'pearson'

    else:
        # Fail early with a clear message instead of an UnboundLocalError further down.
        raise ValueError(f"unknown model_name {args.model_name!r}; "
                         "expected one of 'FFN', 'CNN', 'RNN', 'Transformer'")

    # MLP-embedder config, used only for the 'protospacer_extended' experiment. Tuple layout:
    # (input_dim, embed_dim, mlp_embed_factor, nonlin_func, p_dropout, num_encoder_units)
    if args.exp_name == 'protospacer_extended':
        mlpembedder_tup = (10, 16, 2, embedder_nonlin, 0.1, 1)
        # presumably 20 protospacer inputs + the embedder's 10 inputs — TODO confirm
        xproto_inputsize = 20 + 10
    else:
        mlpembedder_tup = None
        xproto_inputsize = 20

    mconfig, options = build_config_map(args.model_name,
                                        optim_tup,
                                        model_config_tup,
                                        mlpembedder_tup,
                                        loss_func=loss_func_name)

    options['input_size'] = xproto_inputsize
    # NOTE(review): loss_func is already passed to build_config_map; it is set again
    # here so options always carries the name (original marked this "to refactor").
    options['loss_func'] = loss_func_name
    options['model_name'] = args.model_name
    options['perfmetric_name'] = perfmetric_name
    return mconfig, options

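# Debugging aid: uncomment the line below to make CUDA kernel launches synchronous,
# so device errors are reported at the call that caused them (os is imported above).
#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'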


### Run the training workflow for the neural network models

In [5]:
# Models / experiments to train. Full choices: 'FFN', 'CNN', 'RNN', 'Transformer'
# and 'protospacer', 'protospacer_extended'.
model_names = ['RNN']
exp_names = ['protospacer']
# Set to a small int (e.g. 10) to smoke-test the whole workflow for all models;
# None keeps the epoch count defined in get_hyperparam_config.
num_epochs_override = None

dsettypes = ['train', 'validation', 'test']
gpu_index = 0
version = 2  # output folder suffix: <working_dir>/output/<model>_v<version>/<exp_name>
res_desc = {}  # res_desc[model_name][exp_name] -> result of compute_eval_results_df

for model_name in model_names:
    print(model_name)
    args.model_name = model_name  # get_hyperparam_config reads the model choice from args
    res_desc[model_name] = {}
    for exp_name in exp_names:
        args.exp_name = exp_name

        # args.random_seed is not otherwise used in this notebook. Seed the global RNGs
        # so stochastic parts of training (e.g. torch weight initialisation) repeat
        # across reruns. NOTE(review): get_data_ready also receives args and may seed
        # its own data splitting.
        np.random.seed(args.random_seed)
        torch.manual_seed(args.random_seed)

        model_path = os.path.join(args.working_dir,
                                  'output',
                                  f'{model_name}_v{version}',
                                  exp_name)
        dpartitions, datatensor_partitions = get_data_ready(args,
                                                            normalize_opt='max',
                                                            train_size=0.9,
                                                            fdtype=torch.float32,
                                                            plot_y_distrib=False)
        mconfig, options = get_hyperparam_config(args)
        if num_epochs_override is not None:
            options['num_epochs'] = num_epochs_override
        print(options)

        perfmetric_name = options['perfmetric_name']
        train_val_path = os.path.join(model_path, 'train_val')
        test_path = os.path.join(model_path, 'test')  # NOTE(review): not used in this cell

        print(f'Running model: {model_name}, exp_name: {exp_name}, saved at {train_val_path}')
        perfmetric_run_map, score_run_dict = run_trainevaltest_workflow(datatensor_partitions,
                                                                        (mconfig, options),
                                                                        train_val_path,
                                                                        dsettypes,
                                                                        perfmetric_name,
                                                                        gpu_index,
                                                                        to_gpu=True)
        print('=' * 15)
        # summarise the evaluation results of the len(dpartitions) runs saved under train_val_path
        res_desc[model_name][exp_name] = compute_eval_results_df(train_val_path, len(dpartitions))
        

RNN
--- max normalization ---
{'run_num': -1, 'num_epochs': 500, 'weight_decay': 1e-05, 'fdtype': torch.float32, 'to_gpu': True, 'loss_func': 'SmoothL1loss', 'input_size': 20, 'model_name': 'RNN', 'perfmetric_name': 'pearson'}
Running model: RNN, exp_name: protospacer, saved at ./output/RNN_v2/protospacer/train_val
cuda:0
validation
test
number of epochs 500




weight_decay 1e-05
Epoch 1/500, Training Loss: 8.0201, Validation Loss: 7.5960
Epoch 1/500, best pearson corr. on the validation set so far: 0.1687
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 2/500, Training Loss: 7.4207, Validation Loss: 6.5753
Epoch 2/500, best pearson corr. on the validation set so far: 0.2184
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 3/500, Training Loss: 6.2234, Validation Loss: 5.5518
Epoch 3/500, best pearson corr. on the validation set so far: 0.2338
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 4/500, Training Loss: 6.0507, Validation Loss: 5.8290
Epoch 4/500, best pearson corr. on the validation set so far: 0.2567
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 12/500, Training Loss: 5.5505, Validation Loss: 5.3204
Epoch 12/500, best pearson corr. on the validation set so far: 0.2614
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 13/500, Training Loss: 5.5273, Validation Loss: 5.2819
Epoch 13/500, best pearson corr. on the validation set so far: 0.2796
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 14/500, Training Loss: 5.4914, Validation

Epoch 79/500, Training Loss: 2.7911, Validation Loss: 3.3178
Epoch 79/500, best pearson corr. on the validation set so far: 0.7846
~~~~~~~~~~~~~~~~~~~~~~~~~
validation
test
number of epochs 500
weight_decay 1e-05
Epoch 2/500, Training Loss: 7.5046, Validation Loss: 7.0090
Epoch 2/500, best pearson corr. on the validation set so far: 0.0095
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 3/500, Training Loss: 6.3837, Validation Loss: 5.7730
Epoch 3/500, best pearson corr. on the validation set so far: 0.1059
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 4/500, Training Loss: 5.9360, Validation Loss: 6.0606
Epoch 4/500, best pearson corr. on the validation set so far: 0.2710
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 5/500, Training Loss: 5.8899, Validation Loss: 5.7572
Epoch 5/500, best pearson corr. on the validation set so far: 0.3335
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 6/500, Training Loss: 5.8108, Validation Loss: 5.8366
Epoch 6/500, best pearson corr. on the validation set so far: 0.3348
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 7/500,

Epoch 70/500, Training Loss: 2.8241, Validation Loss: 3.3187
Epoch 70/500, best pearson corr. on the validation set so far: 0.7960
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 81/500, Training Loss: 2.6774, Validation Loss: 3.3036
Epoch 81/500, best pearson corr. on the validation set so far: 0.7962
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 85/500, Training Loss: 2.7785, Validation Loss: 3.2852
Epoch 85/500, best pearson corr. on the validation set so far: 0.7974
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 86/500, Training Loss: 2.6853, Validation Loss: 3.4331
Epoch 86/500, best pearson corr. on the validation set so far: 0.7995
~~~~~~~~~~~~~~~~~~~~~~~~~
validation
test
number of epochs 500
weight_decay 1e-05
Epoch 1/500, Training Loss: 8.0868, Validation Loss: 7.8358
Epoch 1/500, best pearson corr. on the validation set so far: 0.0490
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 2/500, Training Loss: 7.6941, Validation Loss: 7.1889
Epoch 2/500, best pearson corr. on the validation set so far: 0.0828
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 

Epoch 64/500, Training Loss: 3.1189, Validation Loss: 3.4079
Epoch 64/500, best pearson corr. on the validation set so far: 0.7795
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 66/500, Training Loss: 3.0967, Validation Loss: 3.3836
Epoch 66/500, best pearson corr. on the validation set so far: 0.7809
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 67/500, Training Loss: 3.0877, Validation Loss: 3.3554
Epoch 67/500, best pearson corr. on the validation set so far: 0.7848
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 69/500, Training Loss: 2.9996, Validation Loss: 3.3420
Epoch 69/500, best pearson corr. on the validation set so far: 0.7866
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 72/500, Training Loss: 2.9693, Validation Loss: 3.3227
Epoch 72/500, best pearson corr. on the validation set so far: 0.7890
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 74/500, Training Loss: 2.9547, Validation Loss: 3.2946
Epoch 74/500, best pearson corr. on the validation set so far: 0.7912
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 75/500, Training Loss: 2.9278, Validation Loss: 3.33

Epoch 36/500, Training Loss: 3.7662, Validation Loss: 4.1761
Epoch 36/500, best pearson corr. on the validation set so far: 0.7329
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 37/500, Training Loss: 3.7259, Validation Loss: 4.0290
Epoch 37/500, best pearson corr. on the validation set so far: 0.7391
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 38/500, Training Loss: 3.6267, Validation Loss: 3.9730
Epoch 38/500, best pearson corr. on the validation set so far: 0.7443
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 39/500, Training Loss: 3.5705, Validation Loss: 3.9622
Epoch 39/500, best pearson corr. on the validation set so far: 0.7499
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 41/500, Training Loss: 3.5321, Validation Loss: 3.9413
Epoch 41/500, best pearson corr. on the validation set so far: 0.7527
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 42/500, Training Loss: 3.5169, Validation Loss: 3.8807
Epoch 42/500, best pearson corr. on the validation set so far: 0.7528
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 43/500, Training Loss: 3.4942, Validation Loss: 3.85

Epoch 27/500, Training Loss: 4.3104, Validation Loss: 4.1954
Epoch 27/500, best pearson corr. on the validation set so far: 0.6856
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 28/500, Training Loss: 4.2305, Validation Loss: 4.1453
Epoch 28/500, best pearson corr. on the validation set so far: 0.6941
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 29/500, Training Loss: 4.1663, Validation Loss: 4.0922
Epoch 29/500, best pearson corr. on the validation set so far: 0.7001
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 30/500, Training Loss: 4.1330, Validation Loss: 4.0743
Epoch 30/500, best pearson corr. on the validation set so far: 0.7020
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 31/500, Training Loss: 4.0990, Validation Loss: 4.0414
Epoch 31/500, best pearson corr. on the validation set so far: 0.7065
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 32/500, Training Loss: 4.0742, Validation Loss: 4.0289
Epoch 32/500, best pearson corr. on the validation set so far: 0.7099
~~~~~~~~~~~~~~~~~~~~~~~~~
Epoch 33/500, Training Loss: 4.0296, Validation Loss: 4.03

Epoch 131/500, Training Loss: 2.0882, Validation Loss: 3.1621
Epoch 131/500, best pearson corr. on the validation set so far: 0.8177
~~~~~~~~~~~~~~~~~~~~~~~~~
run_name: run_0
run_name: run_1
run_name: run_2
run_name: run_3
run_name: run_4


In [6]:
# Evaluation summary per model and experiment: metrics per run plus mean/median/stddev columns
res_desc

{'RNN': {'protospacer':              run_0     run_1     run_2     run_3     run_4      mean   
  spearman  0.765320  0.786523  0.768976  0.766275  0.758588  0.769136  \
  pearson   0.810613  0.815035  0.795311  0.797432  0.796956  0.803069   
  MAE       3.832806  3.877604  3.869510  3.945184  3.703842  3.845789   
  
              median    stddev  
  spearman  0.766275  0.010444  
  pearson   0.797432  0.009075  
  MAE       3.869510  0.089123  }}