In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, random_split
import pandas as pd
import numpy as np
import os
import random
from collections import defaultdict, Counter
from itertools import combinations
import json
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import KFold
import optuna
from torchmetrics.classification import F1Score
import pickle

import sys
### import Dataset prepartion and model training classes from BS_LS_scripts folder
sys.path.insert(1, '/home/wangc90/circRNA/circRNA_Data/BS_LS_scripts/')
from BS_LS_DataSet_2 import BS_LS_DataSet_Prep, RCM_Score
from BS_LS_Training_Base_models_1 import Objective, Objective_CV


In [2]:
class _RCMConvBranch(nn.Module):
    '''
        Shared two-layer Conv1d feature extractor behind the flanking / upper /
        lower RCM branches. The branches differ only in the prefix of their
        Optuna parameter names (e.g. 'flanking_out_channel1').

        Architecture:
            Conv1d(5 -> out_channel1, k=5, s=5) -> BN -> ReLU ->
            Conv1d(out_channel1 -> out_channel2, k=5, s=5) -> BN -> ReLU -> flatten

        Both convs use kernel_size == stride == 5 with no padding, so each layer
        maps a length L to (L - 5) // 5 + 1. The flattened feature size assumes
        conv2 emits exactly `conv2_out_dim` (= 5) positions, which holds for an
        input length in [125, 149].
    '''
    def __init__(self, trial, prefix):
        super(_RCMConvBranch, self).__init__()

        # conv layer 1 (the input must have 5 channels)
        self.out_channel1 = trial.suggest_categorical(f'{prefix}_out_channel1', [128, 256, 512])
        self.conv1 = nn.Conv1d(in_channels=5, out_channels=self.out_channel1,
                               kernel_size=5, stride=5, padding=0)
        self.conv1_bn = nn.BatchNorm1d(self.out_channel1)

        # conv layer 2
        self.out_channel2 = trial.suggest_categorical(f'{prefix}_out_channel2', [128, 256, 512])
        self.conv2 = nn.Conv1d(in_channels=self.out_channel1, out_channels=self.out_channel2,
                               kernel_size=5, stride=5, padding=0)
        self.conv2_bn = nn.BatchNorm1d(self.out_channel2)

        # Number of positions left after conv2; hard-wired to the expected input length.
        self.conv2_out_dim = 5

    def forward(self, x):
        '''
            x: tensor of shape (batch, 5, length) with length in [125, 149].
            Returns a (batch, out_channel2 * conv2_out_dim) feature matrix.

            Raises ValueError when conv2 does not produce `conv2_out_dim` positions.
        '''
        out = torch.relu(self.conv1_bn(self.conv1(x)))
        out = torch.relu(self.conv2_bn(self.conv2(out)))

        # A plain view(-1, C * 5) would silently fold surplus positions into the
        # batch dimension (e.g. length 250 would double the batch); fail loudly instead.
        if out.size(-1) != self.conv2_out_dim:
            raise ValueError(
                f'expected {self.conv2_out_dim} positions after conv2 but got '
                f'{out.size(-1)} (input length {x.size(-1)})')
        return out.flatten(start_dim=1)


class RCM_optuna_flanking(_RCMConvBranch):
    '''
        Conv1d branch that processes the RCM score distribution of the flanking introns.
        Optuna parameters: 'flanking_out_channel1', 'flanking_out_channel2'.
    '''
    def __init__(self, trial):
        super(RCM_optuna_flanking, self).__init__(trial, prefix='flanking')


class RCM_optuna_upper(_RCMConvBranch):
    '''
        Conv1d branch that processes the RCM score distribution of the upper introns.
        Optuna parameters: 'upper_out_channel1', 'upper_out_channel2'.
    '''
    def __init__(self, trial):
        super(RCM_optuna_upper, self).__init__(trial, prefix='upper')


class RCM_optuna_lower(_RCMConvBranch):
    '''
        Conv1d branch that processes the RCM score distribution of the lower introns.
        Optuna parameters: 'lower_out_channel1', 'lower_out_channel2'.
    '''
    def __init__(self, trial):
        super(RCM_optuna_lower, self).__init__(trial, prefix='lower')


class RCM_optuna_concate(nn.Module):
    '''
        Late-fusion classifier over the three RCM branches.

        The flanking, upper and lower inputs each go through their own Conv1d
        branch. The flattened features are concatenated along dim=1 and fed to an MLP:
            fc1 -> BN -> ReLU -> dropout -> fc2 -> BN -> ReLU -> dropout -> fc3 (2 outputs)

        All widths and dropout rates are sampled from the given Optuna `trial`.
    '''
    def __init__(self, trial):

        super(RCM_optuna_concate, self).__init__()

        ### cnn for the flanking rcm scores
        self.cnn_flanking = RCM_optuna_flanking(trial)
        self.flanking_out_dim = self.cnn_flanking.conv2_out_dim
        self.flanking_out_channel = self.cnn_flanking.out_channel2

        ### cnn for the upper rcm scores
        self.cnn_upper = RCM_optuna_upper(trial)
        self.upper_out_dim = self.cnn_upper.conv2_out_dim
        self.upper_out_channel = self.cnn_upper.out_channel2

        ### cnn for the lower rcm scores
        self.cnn_lower = RCM_optuna_lower(trial)
        self.lower_out_dim = self.cnn_lower.conv2_out_dim
        self.lower_out_channel = self.cnn_lower.out_channel2

        # fc1 consumes the concatenation of the three flattened branch outputs
        self.fc1_input_dim = self.flanking_out_dim * self.flanking_out_channel + \
                             self.upper_out_dim * self.upper_out_channel + \
                             self.lower_out_dim * self.lower_out_channel

        # fc layer 1
        self.fc1_out = trial.suggest_categorical('concat_fc1_out', [128, 256, 512])
        self.fc1 = nn.Linear(self.fc1_input_dim, self.fc1_out)
        self.fc1_bn = nn.BatchNorm1d(self.fc1_out)

        dropout_rate_fc1 = trial.suggest_categorical("concat_dropout_rate_fc1", [0, 0.1, 0.2, 0.4])
        self.drop_nn1 = nn.Dropout(p=dropout_rate_fc1)

        # fc layer 2
        self.fc2_out = trial.suggest_categorical('concat_fc2_out', [4, 8, 16, 32])
        self.fc2 = nn.Linear(self.fc1_out, self.fc2_out)
        self.fc2_bn = nn.BatchNorm1d(self.fc2_out)

        dropout_rate_fc2 = trial.suggest_categorical("concat_dropout_rate_fc2", [0, 0.1, 0.2, 0.4])
        self.drop_nn2 = nn.Dropout(p=dropout_rate_fc2)

        # 2 raw logits (no softmax); presumably paired with nn.CrossEntropyLoss
        # in the training objective -- confirm in Objective_CV.
        self.fc3 = nn.Linear(self.fc2_out, 2)

    def forward(self, rcm_flanking, rcm_upper, rcm_lower):
        '''
            Each argument is fed to its branch (see _RCMConvBranch.forward for the
            expected shape). Returns a (batch, 2) tensor of logits.
        '''
        x1 = self.cnn_flanking(rcm_flanking)
        x2 = self.cnn_upper(rcm_upper)
        x3 = self.cnn_lower(rcm_lower)

        x = torch.cat((x1, x2, x3), dim=1)

        # feed the concatenated feature to fc1
        out = self.fc1(x)
        out = self.drop_nn1(torch.relu(self.fc1_bn(out)))
        out = self.fc2(out)
        out = self.drop_nn2(torch.relu(self.fc2_bn(out)))
        out = self.fc3(out)
        return out


In [3]:
def _build_rcm_training_dataset(training_size):
    '''
    Build the RCM_Score dataset (flanking / upper / lower RCM tensors + labels).

    The data come from the fixed BS/LS coordinate, hg19 sequence and RCM score
    paths below. Only the first key list returned by get_train_test_keys() is used.
    '''
    BS_LS_coordinates_path = '/home/wangc90/circRNA/circRNA_Data/BS_LS_data/updated_data/BS_LS_coordinates_final.csv'
    hg19_seq_dict_json_path = '/home/wangc90/circRNA/circRNA_Data/hg19_seq/hg19_seq_dict.json'
    flanking_dict_folder = '/home/wangc90/circRNA/circRNA_Data/BS_LS_data/flanking_dicts/'
    rcm_scores_folder = '/home/wangc90/circRNA/circRNA_Data/BS_LS_data/flanking_dicts/rcm_scores/'

    bs_ls_dataset = BS_LS_DataSet_Prep(BS_LS_coordinates_path=BS_LS_coordinates_path,
                                       hg19_seq_dict_json_path=hg19_seq_dict_json_path,
                                       flanking_dict_folder=flanking_dict_folder,
                                       flanking_junction_bps=100,
                                       flanking_intron_bps=5000,
                                       training_size=training_size)

    ## generate the junction and flanking intron dict
    bs_ls_dataset.get_junction_flanking_intron_seq()

    train_key_1, _, _ = bs_ls_dataset.get_train_test_keys()

    # Only the RCM tensors and labels are needed (is_rcm=True, branches kept separate).
    _, _, flanking_rcm, upper_rcm, lower_rcm, labels = bs_ls_dataset.seq_to_tensor(
        data_keys=train_key_1,
        rcm_folder=rcm_scores_folder,
        is_rcm=True,
        is_upper_lower_concat=False)

    return RCM_Score(flanking_only=False,
                     flanking_rcm=flanking_rcm,
                     upper_rcm=upper_rcm,
                     lower_rcm=lower_rcm,
                     label=labels)


def _write_study_report(study, optuna_folder):
    '''Append a text summary of `study` to optuna.txt and write its trial table to optuna.csv.'''
    pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
    complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

    with open(os.path.join(optuna_folder, 'optuna.txt'), 'a') as f:
        f.write("Study statistics: \n")
        f.write(f"Number of finished trials: {len(study.trials)}\n")
        f.write(f"Number of pruned trials: {len(pruned_trials)}\n")
        f.write(f"Number of complete trials: {len(complete_trials)}\n")

        # study.best_trial raises ValueError when no trial has completed.
        if complete_trials:
            f.write("Best trial:\n")
            trial = study.best_trial
            f.write(f"Value: {trial.value}\n")
            f.write("Params:\n")
            for key, value in trial.params.items():
                f.write(f"{key}:{value}\n")
        else:
            f.write("Best trial: none (no trial completed)\n")

    # errors='ignore' keeps this working for an empty study (no columns to drop).
    df = study.trials_dataframe().drop(
        ['state', 'datetime_start', 'datetime_complete', 'duration', 'number'],
        axis=1, errors='ignore')
    df.to_csv(os.path.join(optuna_folder, 'optuna.csv'), sep='\t', index=False)


def rcm_flankingWithin_large_windows_optuna(num_trial, training_size=9000, cv=3, seed=None):
    '''
    Run an Optuna search over RCM_optuna_concate, scored by `cv`-fold cross-validation.

    Parameters
    ----------
    num_trial : int
        Number of Optuna trials to run.
    training_size : int, default 9000
        Passed to BS_LS_DataSet_Prep; it also names the output sub-folder.
    cv : int, default 3
        Number of CV folds given to Objective_CV; it also names the val-acc folder.
    seed : int or None, default None
        Seed for the TPE sampler. None keeps Optuna's default (unseeded) sampler.

    Side effects
    ------------
    Creates the output folders if missing. Objective_CV receives `val_acc_folder`.
    A summary is appended to `<optuna_folder>/optuna.txt`, and the trial table is
    written to `<optuna_folder>/optuna.csv`.
    '''
    output_root = ('/home/wangc90/circRNA/circRNA_Data/model_outputs/'
                   f'rcm_flankingWithin_large_windows/{training_size}')
    ### where to save the CV validation acc
    val_acc_folder = f'{output_root}/val_acc_cv{cv}'
    ### where to save the detailed optuna results
    optuna_folder = f'{output_root}/optuna'

    # Create the outputs up front so a long search cannot fail at the final write.
    os.makedirs(val_acc_folder, exist_ok=True)
    os.makedirs(optuna_folder, exist_ok=True)

    rcm_dataset = _build_rcm_training_dataset(training_size)
    print(len(rcm_dataset))

    sampler = optuna.samplers.TPESampler(seed=seed) if seed is not None else None
    study = optuna.create_study(sampler=sampler,
                                pruner=optuna.pruners.MedianPruner(n_warmup_steps=1, n_startup_trials=10),
                                direction='maximize')

    study.optimize(Objective_CV(cv=cv, model=RCM_optuna_concate,
                                dataset=rcm_dataset,
                                val_acc_folder=val_acc_folder),
                   n_trials=num_trial, gc_after_trial=True)

    _write_study_report(study, optuna_folder)

In [None]:
# Launch the full hyper-parameter search: 500 Optuna trials (long-running).
rcm_flankingWithin_large_windows_optuna(500)

[32m[I 2023-11-22 19:16:08,080][0m A new study created in memory with name: no-name-291bfcac-94ac-46ad-8f1e-ed8f38f217c3[0m


chr5|138837130|138837392|- has N in the extracted junctions, belongs to BS
There are 0 overlapped flanking sequence from BS and LS  
There are 7 repeated BS sequences
There are 2 repeated LS sequences


[32m[I 2023-11-22 19:18:30,720][0m A new study created in memory with name: no-name-5d772124-aba4-46ef-b297-d8fbb35223d0[0m


18000
fold 1, epoch 20, val loss 9.291157007217407 val accuracy 0.5987
fold 1, epoch 40, val loss 10.905982732772827 val accuracy 0.5978
fold 1, epoch 60, val loss 11.891278982162476 val accuracy 0.6042
fold 1, epoch 80, val loss 12.870460987091064 val accuracy 0.6052
fold 1, epoch 100, val loss 13.779450178146362 val accuracy 0.6083
fold 1, epoch 120, val loss 14.70875597000122 val accuracy 0.6065
fold 2, epoch 20, val loss 9.554250061511993 val accuracy 0.6037
fold 2, epoch 40, val loss 10.882030069828033 val accuracy 0.6115
fold 2, epoch 60, val loss 12.130437672138214 val accuracy 0.6148
fold 2, epoch 80, val loss 13.293223977088928 val accuracy 0.6073
fold 2, epoch 100, val loss 14.06904399394989 val accuracy 0.6097
fold 2, epoch 120, val loss 15.212745070457458 val accuracy 0.5968
fold 3, epoch 20, val loss 9.260751247406006 val accuracy 0.6123
fold 3, epoch 40, val loss 10.35059130191803 val accuracy 0.6137
fold 3, epoch 60, val loss 11.471743404865265 val accuracy 0.6113
fold 3

[32m[I 2023-11-22 19:37:17,208][0m Trial 0 finished with value: 0.6060333333333333 and parameters: {'lr': 0.00019722237181510606, 'l2_lambda': 2.424020142512239e-09, 'batch_size': 512, 'epochs': 120, 'flanking_out_channel1': 256, 'flanking_out_channel2': 256, 'upper_out_channel1': 256, 'upper_out_channel2': 512, 'lower_out_channel1': 512, 'lower_out_channel2': 128, 'concat_fc1_out': 512, 'concat_dropout_rate_fc1': 0, 'concat_fc2_out': 16, 'concat_dropout_rate_fc2': 0}. Best is trial 0 with value: 0.6060333333333333.[0m


fold 3, epoch 120, val loss 13.583281993865967 val accuracy 0.6148
fold 1, epoch 20, val loss 8.292022347450256 val accuracy 0.5763
fold 1, epoch 40, val loss 8.705272138118744 val accuracy 0.5878
fold 1, epoch 60, val loss 9.248005330562592 val accuracy 0.5885
fold 1, epoch 80, val loss 9.648603737354279 val accuracy 0.5922
fold 2, epoch 20, val loss 8.277420341968536 val accuracy 0.5808
fold 2, epoch 40, val loss 8.431080222129822 val accuracy 0.5978
fold 2, epoch 60, val loss 8.65621268749237 val accuracy 0.5973
fold 2, epoch 80, val loss 9.02912813425064 val accuracy 0.6017
fold 3, epoch 20, val loss 8.341042220592499 val accuracy 0.5588
fold 3, epoch 40, val loss 8.302250444889069 val accuracy 0.5885
fold 3, epoch 60, val loss 8.708787977695465 val accuracy 0.5933
fold 3, epoch 80, val loss 8.945180773735046 val accuracy 0.5998


[32m[I 2023-11-22 19:50:41,540][0m Trial 1 finished with value: 0.5966 and parameters: {'lr': 3.516763123420217e-05, 'l2_lambda': 4.710961880803116e-07, 'batch_size': 512, 'epochs': 90, 'flanking_out_channel1': 256, 'flanking_out_channel2': 512, 'upper_out_channel1': 512, 'upper_out_channel2': 256, 'lower_out_channel1': 256, 'lower_out_channel2': 128, 'concat_fc1_out': 256, 'concat_dropout_rate_fc1': 0.4, 'concat_fc2_out': 16, 'concat_dropout_rate_fc2': 0}. Best is trial 0 with value: 0.6060333333333333.[0m


fold 1, epoch 20, val loss 33.52337563037872 val accuracy 0.503
fold 1, epoch 40, val loss 32.8447260260582 val accuracy 0.5377
fold 1, epoch 60, val loss 32.06057530641556 val accuracy 0.5877
fold 1, epoch 80, val loss 32.405984699726105 val accuracy 0.6048
fold 1, epoch 100, val loss 32.68707633018494 val accuracy 0.6082
fold 1, epoch 120, val loss 33.34649050235748 val accuracy 0.6133
fold 1, epoch 140, val loss 34.948686361312866 val accuracy 0.6133
fold 2, epoch 20, val loss 31.429759204387665 val accuracy 0.6037
fold 2, epoch 40, val loss 31.521911203861237 val accuracy 0.6163
fold 2, epoch 60, val loss 31.763312876224518 val accuracy 0.6205
fold 2, epoch 80, val loss 32.58815848827362 val accuracy 0.6215
fold 2, epoch 100, val loss 33.838620364665985 val accuracy 0.619
fold 2, epoch 120, val loss 34.703016579151154 val accuracy 0.6255
fold 2, epoch 140, val loss 36.5053209066391 val accuracy 0.6263
fold 3, epoch 20, val loss 31.240939557552338 val accuracy 0.6002
fold 3, epoch 4

[32m[I 2023-11-22 20:22:41,192][0m Trial 2 finished with value: 0.6233333333333333 and parameters: {'lr': 1.961028811739023e-05, 'l2_lambda': 2.1621350941963873e-09, 'batch_size': 128, 'epochs': 150, 'flanking_out_channel1': 512, 'flanking_out_channel2': 256, 'upper_out_channel1': 128, 'upper_out_channel2': 128, 'lower_out_channel1': 512, 'lower_out_channel2': 256, 'concat_fc1_out': 128, 'concat_dropout_rate_fc1': 0.2, 'concat_fc2_out': 4, 'concat_dropout_rate_fc2': 0}. Best is trial 2 with value: 0.6233333333333333.[0m


fold 1, epoch 20, val loss 30.60913920402527 val accuracy 0.6195
fold 1, epoch 40, val loss 36.65914511680603 val accuracy 0.6233
fold 1, epoch 60, val loss 48.23774206638336 val accuracy 0.6275
fold 1, epoch 80, val loss 43.32447326183319 val accuracy 0.6183
fold 2, epoch 20, val loss 30.37846803665161 val accuracy 0.6103
fold 2, epoch 40, val loss 36.512526750564575 val accuracy 0.6208
fold 2, epoch 60, val loss 40.31506025791168 val accuracy 0.6278
fold 2, epoch 80, val loss 43.25117754936218 val accuracy 0.6288
fold 3, epoch 20, val loss 33.492743372917175 val accuracy 0.5947
fold 3, epoch 40, val loss 38.17717707157135 val accuracy 0.6198
fold 3, epoch 60, val loss 41.252912521362305 val accuracy 0.6143
fold 3, epoch 80, val loss 44.872756481170654 val accuracy 0.6152


[32m[I 2023-11-22 20:38:41,718][0m Trial 3 finished with value: 0.6106666666666666 and parameters: {'lr': 0.00024178478085221878, 'l2_lambda': 1.577314004981381e-08, 'batch_size': 256, 'epochs': 90, 'flanking_out_channel1': 128, 'flanking_out_channel2': 512, 'upper_out_channel1': 128, 'upper_out_channel2': 256, 'lower_out_channel1': 512, 'lower_out_channel2': 512, 'concat_fc1_out': 256, 'concat_dropout_rate_fc1': 0.4, 'concat_fc2_out': 32, 'concat_dropout_rate_fc2': 0.1}. Best is trial 2 with value: 0.6233333333333333.[0m


fold 1, epoch 20, val loss 61.96321952342987 val accuracy 0.6167
fold 1, epoch 40, val loss 63.23557037115097 val accuracy 0.6292
fold 1, epoch 60, val loss 68.92352056503296 val accuracy 0.621
fold 2, epoch 20, val loss 62.18353056907654 val accuracy 0.6075
fold 2, epoch 40, val loss 63.95232081413269 val accuracy 0.6193
fold 2, epoch 60, val loss 69.37877708673477 val accuracy 0.6222
fold 3, epoch 20, val loss 62.34079748392105 val accuracy 0.6053
fold 3, epoch 40, val loss 64.38355469703674 val accuracy 0.6283


[32m[I 2023-11-22 20:57:41,043][0m Trial 4 finished with value: 0.6245666666666666 and parameters: {'lr': 2.3465236024401466e-05, 'l2_lambda': 9.415397952224257e-07, 'batch_size': 64, 'epochs': 60, 'flanking_out_channel1': 256, 'flanking_out_channel2': 256, 'upper_out_channel1': 128, 'upper_out_channel2': 512, 'lower_out_channel1': 128, 'lower_out_channel2': 512, 'concat_fc1_out': 512, 'concat_dropout_rate_fc1': 0.4, 'concat_fc2_out': 4, 'concat_dropout_rate_fc2': 0}. Best is trial 4 with value: 0.6245666666666666.[0m


fold 3, epoch 60, val loss 70.68210887908936 val accuracy 0.6305
fold 1, epoch 20, val loss 64.4973514676094 val accuracy 0.5713
fold 1, epoch 40, val loss 66.75803780555725 val accuracy 0.5785
fold 1, epoch 60, val loss 68.2635503411293 val accuracy 0.5893
fold 2, epoch 20, val loss 63.84495681524277 val accuracy 0.5795
fold 2, epoch 40, val loss 65.05863380432129 val accuracy 0.5853
fold 2, epoch 60, val loss 67.36334896087646 val accuracy 0.593
fold 3, epoch 20, val loss 66.8430939912796 val accuracy 0.5683
fold 3, epoch 40, val loss 70.78991287946701 val accuracy 0.5802


[32m[I 2023-11-22 21:16:30,362][0m Trial 5 finished with value: 0.5907666666666667 and parameters: {'lr': 1.1251588416523185e-05, 'l2_lambda': 1.483896219361958e-07, 'batch_size': 64, 'epochs': 60, 'flanking_out_channel1': 256, 'flanking_out_channel2': 256, 'upper_out_channel1': 256, 'upper_out_channel2': 128, 'lower_out_channel1': 512, 'lower_out_channel2': 128, 'concat_fc1_out': 256, 'concat_dropout_rate_fc1': 0, 'concat_fc2_out': 8, 'concat_dropout_rate_fc2': 0.1}. Best is trial 4 with value: 0.6245666666666666.[0m


fold 3, epoch 60, val loss 73.65111100673676 val accuracy 0.59
fold 1, epoch 20, val loss 123.58109498023987 val accuracy 0.6173
fold 1, epoch 40, val loss 132.07362908124924 val accuracy 0.6373
fold 1, epoch 60, val loss 159.50348231196404 val accuracy 0.6403
fold 1, epoch 80, val loss 182.60602521896362 val accuracy 0.6542
fold 1, epoch 100, val loss 221.18193820118904 val accuracy 0.6495
fold 1, epoch 120, val loss 259.63651460409164 val accuracy 0.6352
fold 1, epoch 140, val loss 275.21646854281425 val accuracy 0.6495
fold 2, epoch 20, val loss 121.94330984354019 val accuracy 0.6173
fold 2, epoch 40, val loss 127.466105312109 val accuracy 0.6475
fold 2, epoch 60, val loss 145.39070004224777 val accuracy 0.6488
fold 2, epoch 80, val loss 175.1017121374607 val accuracy 0.6575
fold 2, epoch 100, val loss 207.82079684734344 val accuracy 0.6545
fold 2, epoch 120, val loss 238.94467881321907 val accuracy 0.6543
