In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, random_split
import pandas as pd
import numpy as np
import os
import random
from collections import defaultdict, Counter
from itertools import combinations
import json
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import KFold
import optuna
from torchmetrics.classification import F1Score
import pickle

import sys
### import the dataset-preparation and model-training classes from the BS_LS_scripts folder
sys.path.insert(1, '/home/wangc90/circRNA/circRNA_Data/BS_LS_scripts/')
from BS_LS_DataSet import BS_LS_DataSet_Prep, RCM_Score
from BS_LS_Training_Base_models_0 import Objective, Objective_CV


In [2]:
class RCM_optuna_flanking(nn.Module):
    '''
        Conv1d + MLP classifier for the RCM score profile of the flanking introns.

        Each sample is a 2-D (channels=5, length) matrix of RCM scores. Every layer
        width and dropout rate is drawn from an Optuna ``trial`` through
        ``trial.suggest_categorical``, so the architecture can be tuned.

        Architecture:
            Conv1d(5 -> C1, kernel 5, stride 5) -> BatchNorm -> ReLU
            Conv1d(C1 -> C2, kernel 5, stride 5) -> BatchNorm -> ReLU
            flatten -> Linear(C2 * conv2_out_dim -> fc1_out) -> BatchNorm -> ReLU -> Dropout
            Linear(fc1_out -> fc2_out) -> BatchNorm -> ReLU -> Dropout
            Linear(fc2_out -> 2)   # two raw logits (intended for nn.CrossEntropyLoss)

        Args:
            trial: Optuna trial, or any object exposing ``suggest_categorical(name, choices)``.
            conv2_out_dim: length of the sequence axis after the second conv layer.
                With two kernel-5 / stride-5 convs and no padding this equals
                ``(input_length // 5) // 5``. The default of 5 therefore corresponds
                to input lengths 125..149.
                NOTE(review): confirm against the tensors produced by RCM_Score.
    '''
    def __init__(self, trial, conv2_out_dim=5):

        super().__init__()

        # conv layer 1: 5 input channels, non-overlapping windows of 5 positions
        self.out_channel1 = trial.suggest_categorical('flanking_out_channel1', [128, 256, 512])
        self.conv1 = nn.Conv1d(in_channels=5, out_channels=self.out_channel1,
                               kernel_size=5, stride=5, padding=0)
        self.conv1_bn = nn.BatchNorm1d(self.out_channel1)

        # conv layer 2: again non-overlapping windows of 5 positions
        self.out_channel2 = trial.suggest_categorical('flanking_out_channel2', [128, 256, 512])
        self.conv2 = nn.Conv1d(in_channels=self.out_channel1, out_channels=self.out_channel2,
                               kernel_size=5, stride=5, padding=0)
        self.conv2_bn = nn.BatchNorm1d(self.out_channel2)

        # Sequence length remaining after both convs; fixes fc1's input size.
        self.conv2_out_dim = conv2_out_dim

        # fc layer 1 consumes the flattened conv2 feature map (channels x length).
        self.fc1_input_dim = self.conv2_out_dim * self.out_channel2

        self.fc1_out = trial.suggest_categorical('fc1_out', [128, 256, 512])
        self.fc1 = nn.Linear(self.fc1_input_dim, self.fc1_out)
        self.fc1_bn = nn.BatchNorm1d(self.fc1_out)

        # The "concat_" prefix is kept so parameter names match earlier Optuna studies.
        dropout_rate_fc1 = trial.suggest_categorical("concat_dropout_rate_fc1", [0, 0.1, 0.2, 0.4])
        self.drop_nn1 = nn.Dropout(p=dropout_rate_fc1)

        # fc layer 2
        self.fc2_out = trial.suggest_categorical('fc2_out', [4, 8, 16, 32])
        self.fc2 = nn.Linear(self.fc1_out, self.fc2_out)
        self.fc2_bn = nn.BatchNorm1d(self.fc2_out)

        dropout_rate_fc2 = trial.suggest_categorical("concat_dropout_rate_fc2", [0, 0.1, 0.2, 0.4])
        self.drop_nn2 = nn.Dropout(p=dropout_rate_fc2)

        # Output layer: 2 logits, used with nn.CrossEntropyLoss (no softmax here).
        self.fc3 = nn.Linear(self.fc2_out, 2)

    def forward(self, x):
        '''
            Args:
                x: float tensor of shape (batch, 5, length), the (N, C, L) layout
                    that Conv1d expects.
            Returns:
                Logits of shape (batch, 2).
            Raises:
                RuntimeError: if the conv output length does not equal
                    ``conv2_out_dim`` (fc1 input-size mismatch).
        '''
        out = torch.relu(self.conv1_bn(self.conv1(x)))
        out = torch.relu(self.conv2_bn(self.conv2(out)))

        # Flatten per sample while keeping the batch axis. Unlike
        # view(-1, C * L), this can never silently change the batch size.
        # A wrong input length therefore fails right here at fc1 with a shape error.
        out = out.flatten(start_dim=1)

        out = self.drop_nn1(torch.relu(self.fc1_bn(self.fc1(out))))
        out = self.drop_nn2(torch.relu(self.fc2_bn(self.fc2(out))))
        return self.fc3(out)


In [3]:
def _build_flanking_rcm_dataset(training_size):
    '''
        Build an RCM_Score dataset holding only the flanking-intron RCM scores and
        labels for the training keys of BS_LS_DataSet_Prep.
    '''
    bs_ls_dataset = BS_LS_DataSet_Prep(
        BS_LS_coordinates_path='/home/wangc90/circRNA/circRNA_Data/BS_LS_data/updated_data/BS_LS_coordinates_final.csv',
        hg19_seq_dict_json_path='/home/wangc90/circRNA/circRNA_Data/hg19_seq/hg19_seq_dict.json',
        flanking_dict_folder='/home/wangc90/circRNA/circRNA_Data/BS_LS_data/flanking_dicts/',
        flanking_junction_bps=100,
        flanking_intron_bps=5000,
        training_size=training_size)

    # generate the junction and flanking-intron sequence dicts
    bs_ls_dataset.get_junction_flanking_intron_seq()

    # only the first returned key set is used for training/CV here
    train_keys, _, _ = bs_ls_dataset.get_train_test_keys()

    rcm_scores_folder = '/home/wangc90/circRNA/circRNA_Data/BS_LS_data/flanking_dicts/rcm_scores/'

    # seq_to_tensor returns six items; only the 3rd (flanking RCM scores) and
    # the 6th (labels) are needed for the flanking-only model.
    _, _, train_flanking_rcm, _, _, train_labels = bs_ls_dataset.seq_to_tensor(
        data_keys=train_keys, rcm_folder=rcm_scores_folder,
        is_rcm=True, is_upper_lower_concat=False)

    return RCM_Score(flanking_only=True,
                     flanking_rcm=train_flanking_rcm,
                     upper_rcm=None,
                     lower_rcm=None,
                     label=train_labels)


def _write_optuna_summary(study, optuna_folder):
    '''
        Append a text summary of ``study`` to <optuna_folder>/optuna.txt and write
        the per-trial table to <optuna_folder>/optuna.csv (tab-separated).
    '''
    pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
    complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

    with open(os.path.join(optuna_folder, 'optuna.txt'), 'a') as f:
        f.write("Study statistics: \n")
        f.write(f"Number of finished trials: {len(study.trials)}\n")
        f.write(f"Number of pruned trials: {len(pruned_trials)}\n")
        f.write(f"Number of complete trials: {len(complete_trials)}\n")

        if complete_trials:
            f.write("Best trial:\n")
            best_trial = study.best_trial
            f.write(f"Value: {best_trial.value}\n")
            f.write("Params:\n")
            for key, value in best_trial.params.items():
                f.write(f"{key}:{value}\n")
        else:
            # study.best_trial raises ValueError when no trial has completed
            # (e.g. every trial was pruned); still record the statistics above.
            f.write("Best trial: none (no trial completed)\n")

    trials_df = study.trials_dataframe().drop(
        columns=['state', 'datetime_start', 'datetime_complete', 'duration', 'number'],
        errors='ignore')  # an empty study has none of these columns
    trials_df.to_csv(os.path.join(optuna_folder, 'optuna.csv'), sep='\t', index=False)


def rcm_flankingOnly_small_windows_optuna(num_trial, training_size=8000, cv=3):
    '''
        Tune RCM_optuna_flanking with Optuna, scoring every trial by k-fold
        cross-validation on the flanking-intron RCM scores only.

        Args:
            num_trial: number of Optuna trials to run.
            training_size: passed to BS_LS_DataSet_Prep(training_size=...). Also used
                as the output sub-folder name (default 8000, as in the original runs).
            cv: number of cross-validation folds given to Objective_CV (default 3).

        Side effects:
            * creates the output folders if they are missing;
            * passes val_acc_folder to Objective_CV
              (NOTE(review): presumably it writes per-fold accuracies there; confirm);
            * appends a summary to <optuna_folder>/optuna.txt and writes
              <optuna_folder>/optuna.csv.

        Returns:
            The finished optuna study, for interactive inspection.
    '''
    output_root = ('/home/wangc90/circRNA/circRNA_Data/model_outputs/'
                   f'rcm_flankingOnly_small_windows/{training_size}')
    # where the CV validation accuracies go
    val_acc_folder = f'{output_root}/val_acc_cv{cv}'
    # where the Optuna summary and trial table go
    optuna_folder = f'{output_root}/optuna'
    # Create the folders up front so a multi-hour search cannot crash at the final write.
    os.makedirs(val_acc_folder, exist_ok=True)
    os.makedirs(optuna_folder, exist_ok=True)

    rcm_dataset = _build_flanking_rcm_dataset(training_size)
    print(len(rcm_dataset))

    # The median pruner only starts pruning after 10 trials and 1 warm-up step.
    study = optuna.create_study(pruner=optuna.pruners.MedianPruner(n_warmup_steps=1, n_startup_trials=10),
                                direction='maximize')

    study.optimize(Objective_CV(cv=cv, model=RCM_optuna_flanking,
                                dataset=rcm_dataset,
                                val_acc_folder=val_acc_folder),
                   n_trials=num_trial, gc_after_trial=True)

    _write_optuna_summary(study, optuna_folder)
    return study

In [None]:
# Launch the hyper-parameter search: 500 Optuna trials, each scored by 3-fold CV.
# The trailing ';' stops Jupyter from echoing the call's return value.
rcm_flankingOnly_small_windows_optuna(num_trial=500);

[32m[I 2023-11-22 19:16:24,722][0m A new study created in memory with name: no-name-c527781f-4c57-4c51-a808-7d584ce8940f[0m


chr5|138837130|138837392|- has N in the extracted junctions, belongs to BS
There are 0 overlapped flanking sequence from BS and LS  
There are 7 repeated BS sequences
There are 2 repeated LS sequences


[32m[I 2023-11-22 19:18:15,919][0m A new study created in memory with name: no-name-1fdd14cd-5db4-48d5-b56f-7b4ddda9411a[0m


16000
fold 1, epoch 20, val loss 19.150680541992188 val accuracy 0.5206
fold 1, epoch 40, val loss 27.26481556892395 val accuracy 0.5217
fold 1, epoch 60, val loss 32.45074534416199 val accuracy 0.5129
fold 2, epoch 20, val loss 19.996910214424133 val accuracy 0.5091
fold 2, epoch 40, val loss 27.346073508262634 val accuracy 0.513
fold 2, epoch 60, val loss 32.75779449939728 val accuracy 0.51
fold 3, epoch 20, val loss 19.98124873638153 val accuracy 0.5149
fold 3, epoch 40, val loss 27.53920841217041 val accuracy 0.5029


[32m[I 2023-11-22 19:27:37,861][0m Trial 0 finished with value: 0.5123666666666666 and parameters: {'lr': 0.0003473762017251187, 'l2_lambda': 4.135726959669756e-07, 'batch_size': 256, 'epochs': 60, 'flanking_out_channel1': 512, 'flanking_out_channel2': 128, 'fc1_out': 512, 'concat_dropout_rate_fc1': 0.1, 'fc2_out': 4, 'concat_dropout_rate_fc2': 0.4}. Best is trial 0 with value: 0.5123666666666666.[0m


fold 3, epoch 60, val loss 50.5105414390564 val accuracy 0.5142
fold 1, epoch 20, val loss 127.78898602724075 val accuracy 0.5272
fold 1, epoch 40, val loss 155.79700911045074 val accuracy 0.5187
fold 1, epoch 60, val loss 199.8129741549492 val accuracy 0.521
fold 1, epoch 80, val loss 237.26033401489258 val accuracy 0.5148
fold 2, epoch 20, val loss 135.61321812868118 val accuracy 0.5104
fold 2, epoch 40, val loss 150.02754080295563 val accuracy 0.5127
fold 2, epoch 60, val loss 176.12221068143845 val accuracy 0.5207
fold 2, epoch 80, val loss 224.8861272931099 val accuracy 0.5218
fold 3, epoch 20, val loss 126.83238756656647 val accuracy 0.5322
fold 3, epoch 40, val loss 155.90197545289993 val accuracy 0.5211
fold 3, epoch 60, val loss 192.30604356527328 val accuracy 0.5211
fold 3, epoch 80, val loss 234.7103355526924 val accuracy 0.5293


[32m[I 2023-11-22 20:03:19,749][0m Trial 1 finished with value: 0.5191333333333333 and parameters: {'lr': 1.548278365193729e-05, 'l2_lambda': 2.3203560337514604e-07, 'batch_size': 32, 'epochs': 90, 'flanking_out_channel1': 256, 'flanking_out_channel2': 512, 'fc1_out': 128, 'concat_dropout_rate_fc1': 0, 'fc2_out': 8, 'concat_dropout_rate_fc2': 0}. Best is trial 1 with value: 0.5191333333333333.[0m


fold 1, epoch 20, val loss 7.6445392370224 val accuracy 0.5075
fold 1, epoch 40, val loss 7.6772913336753845 val accuracy 0.5082
fold 1, epoch 60, val loss 7.704340636730194 val accuracy 0.5092
fold 1, epoch 80, val loss 7.731413125991821 val accuracy 0.5137
fold 1, epoch 100, val loss 7.764158546924591 val accuracy 0.5137
fold 1, epoch 120, val loss 7.772045433521271 val accuracy 0.5172
fold 2, epoch 20, val loss 8.289597988128662 val accuracy 0.5181
fold 2, epoch 40, val loss 8.450882017612457 val accuracy 0.5166
fold 2, epoch 60, val loss 8.553983092308044 val accuracy 0.5168
fold 2, epoch 80, val loss 8.580121338367462 val accuracy 0.519
fold 2, epoch 100, val loss 8.823643863201141 val accuracy 0.5153
fold 2, epoch 120, val loss 8.792267382144928 val accuracy 0.5162
fold 3, epoch 20, val loss 8.455825865268707 val accuracy 0.5106
fold 3, epoch 40, val loss 8.677291512489319 val accuracy 0.5091
fold 3, epoch 60, val loss 8.873716711997986 val accuracy 0.5115
fold 3, epoch 80, val l

[32m[I 2023-11-22 20:19:31,067][0m Trial 2 finished with value: 0.5166 and parameters: {'lr': 2.8679158962618775e-05, 'l2_lambda': 3.2736984913195935e-07, 'batch_size': 512, 'epochs': 120, 'flanking_out_channel1': 256, 'flanking_out_channel2': 256, 'fc1_out': 512, 'concat_dropout_rate_fc1': 0, 'fc2_out': 4, 'concat_dropout_rate_fc2': 0.1}. Best is trial 1 with value: 0.5191333333333333.[0m


fold 3, epoch 120, val loss 9.150561690330505 val accuracy 0.5164
fold 1, epoch 20, val loss 173.29015696048737 val accuracy 0.4948
fold 1, epoch 40, val loss 213.32481837272644 val accuracy 0.4993
fold 1, epoch 60, val loss 239.92540955543518 val accuracy 0.4989
fold 1, epoch 80, val loss 252.07414424419403 val accuracy 0.5051
fold 1, epoch 100, val loss 269.1898331642151 val accuracy 0.5082
fold 1, epoch 120, val loss 290.48957991600037 val accuracy 0.5097
fold 2, epoch 20, val loss 168.05505108833313 val accuracy 0.5207
fold 2, epoch 40, val loss 210.4926106929779 val accuracy 0.5158
fold 2, epoch 60, val loss 234.9144184589386 val accuracy 0.5142
fold 2, epoch 80, val loss 245.49391841888428 val accuracy 0.5155
fold 2, epoch 100, val loss 254.50145661830902 val accuracy 0.5185
fold 2, epoch 120, val loss 265.76896142959595 val accuracy 0.5089
fold 3, epoch 20, val loss 162.73945713043213 val accuracy 0.5078
fold 3, epoch 40, val loss 207.12280941009521 val accuracy 0.5112
fold 3, e

[32m[I 2023-11-22 20:50:15,558][0m Trial 3 finished with value: 0.5129333333333334 and parameters: {'lr': 0.0003824578003141549, 'l2_lambda': 6.471144391734089e-08, 'batch_size': 64, 'epochs': 120, 'flanking_out_channel1': 512, 'flanking_out_channel2': 512, 'fc1_out': 128, 'concat_dropout_rate_fc1': 0.2, 'fc2_out': 32, 'concat_dropout_rate_fc2': 0}. Best is trial 1 with value: 0.5191333333333333.[0m


fold 3, epoch 120, val loss 267.5035911798477 val accuracy 0.5202
fold 1, epoch 20, val loss 12.110345482826233 val accuracy 0.5021
fold 1, epoch 40, val loss 20.35155713558197 val accuracy 0.5077
fold 1, epoch 60, val loss 26.27784562110901 val accuracy 0.512
fold 2, epoch 20, val loss 11.623866736888885 val accuracy 0.5158
fold 2, epoch 40, val loss 19.51620316505432 val accuracy 0.5153
fold 2, epoch 60, val loss 24.264650583267212 val accuracy 0.5158
fold 3, epoch 20, val loss 10.81858229637146 val accuracy 0.5312
fold 3, epoch 40, val loss 18.26066744327545 val accuracy 0.5263


[32m[I 2023-11-22 20:58:21,835][0m Trial 4 finished with value: 0.5167333333333334 and parameters: {'lr': 0.00024104447440313916, 'l2_lambda': 4.703696183570239e-08, 'batch_size': 512, 'epochs': 60, 'flanking_out_channel1': 128, 'flanking_out_channel2': 256, 'fc1_out': 128, 'concat_dropout_rate_fc1': 0.4, 'fc2_out': 32, 'concat_dropout_rate_fc2': 0.2}. Best is trial 1 with value: 0.5191333333333333.[0m


fold 3, epoch 60, val loss 23.453925132751465 val accuracy 0.5224
fold 1, epoch 20, val loss 147.58953595161438 val accuracy 0.5088
fold 1, epoch 40, val loss 210.8513684272766 val accuracy 0.5088
fold 1, epoch 60, val loss 219.226988196373 val accuracy 0.5165
fold 2, epoch 20, val loss 152.83612966537476 val accuracy 0.5179
fold 2, epoch 40, val loss 201.24972128868103 val accuracy 0.5065
fold 2, epoch 60, val loss 226.62519121170044 val accuracy 0.5117
fold 3, epoch 20, val loss 141.32909727096558 val accuracy 0.534
fold 3, epoch 40, val loss 185.91094744205475 val accuracy 0.5361


[32m[I 2023-11-22 21:13:05,699][0m Trial 5 finished with value: 0.5187333333333334 and parameters: {'lr': 0.0001318387590263115, 'l2_lambda': 4.01695463981504e-09, 'batch_size': 64, 'epochs': 60, 'flanking_out_channel1': 128, 'flanking_out_channel2': 128, 'fc1_out': 128, 'concat_dropout_rate_fc1': 0, 'fc2_out': 32, 'concat_dropout_rate_fc2': 0}. Best is trial 1 with value: 0.5191333333333333.[0m


fold 3, epoch 60, val loss 209.26243937015533 val accuracy 0.528
fold 1, epoch 20, val loss 15.243024468421936 val accuracy 0.533
fold 1, epoch 40, val loss 15.93897533416748 val accuracy 0.5322
fold 1, epoch 60, val loss 16.268383026123047 val accuracy 0.5347
fold 2, epoch 20, val loss 15.447180271148682 val accuracy 0.5224
fold 2, epoch 40, val loss 15.652762651443481 val accuracy 0.525
fold 2, epoch 60, val loss 16.509978890419006 val accuracy 0.5192
fold 3, epoch 20, val loss 15.851193726062775 val accuracy 0.4975
fold 3, epoch 40, val loss 15.744989693164825 val accuracy 0.5093


[32m[I 2023-11-22 21:22:22,107][0m Trial 6 finished with value: 0.523 and parameters: {'lr': 3.549905707350068e-05, 'l2_lambda': 1.291013155276939e-09, 'batch_size': 256, 'epochs': 60, 'flanking_out_channel1': 512, 'flanking_out_channel2': 256, 'fc1_out': 256, 'concat_dropout_rate_fc1': 0.1, 'fc2_out': 4, 'concat_dropout_rate_fc2': 0.2}. Best is trial 6 with value: 0.523.[0m


fold 3, epoch 60, val loss 15.728446364402771 val accuracy 0.5151
fold 1, epoch 20, val loss 16.985926866531372 val accuracy 0.5129
fold 1, epoch 40, val loss 19.36303871870041 val accuracy 0.5135
fold 1, epoch 60, val loss 22.70515286922455 val accuracy 0.5221
fold 2, epoch 20, val loss 16.957006812095642 val accuracy 0.5108
fold 2, epoch 40, val loss 19.878898859024048 val accuracy 0.5082
fold 2, epoch 60, val loss 23.32651960849762 val accuracy 0.5162
fold 3, epoch 20, val loss 16.591112077236176 val accuracy 0.517
fold 3, epoch 40, val loss 19.46333384513855 val accuracy 0.5252


[32m[I 2023-11-22 21:32:29,941][0m Trial 7 finished with value: 0.5205000000000001 and parameters: {'lr': 0.00014837349976193467, 'l2_lambda': 1.2335539939492151e-07, 'batch_size': 256, 'epochs': 60, 'flanking_out_channel1': 512, 'flanking_out_channel2': 512, 'fc1_out': 128, 'concat_dropout_rate_fc1': 0.2, 'fc2_out': 4, 'concat_dropout_rate_fc2': 0.1}. Best is trial 6 with value: 0.523.[0m


fold 3, epoch 60, val loss 22.619086265563965 val accuracy 0.5232
fold 1, epoch 20, val loss 59.34304255247116 val accuracy 0.5358
fold 1, epoch 40, val loss 63.892914831638336 val accuracy 0.5253
fold 1, epoch 60, val loss 74.75419074296951 val accuracy 0.5231
fold 2, epoch 20, val loss 58.392478942871094 val accuracy 0.5481
fold 2, epoch 40, val loss 62.68725502490997 val accuracy 0.5427
fold 2, epoch 60, val loss 73.66594642400742 val accuracy 0.5329
fold 3, epoch 20, val loss 58.94516450166702 val accuracy 0.5391
fold 3, epoch 40, val loss 63.619608640670776 val accuracy 0.5325


[32m[I 2023-11-22 21:47:25,952][0m Trial 8 finished with value: 0.5247666666666667 and parameters: {'lr': 1.1352044504063636e-05, 'l2_lambda': 2.2974374544965392e-07, 'batch_size': 64, 'epochs': 60, 'flanking_out_channel1': 128, 'flanking_out_channel2': 512, 'fc1_out': 256, 'concat_dropout_rate_fc1': 0.4, 'fc2_out': 32, 'concat_dropout_rate_fc2': 0.1}. Best is trial 8 with value: 0.5247666666666667.[0m


fold 3, epoch 60, val loss 74.3215982913971 val accuracy 0.5183
fold 1, epoch 20, val loss 14.932204246520996 val accuracy 0.5092
fold 1, epoch 40, val loss 15.011396825313568 val accuracy 0.5154
fold 1, epoch 60, val loss 15.189367294311523 val accuracy 0.5167
fold 1, epoch 80, val loss 15.319118678569794 val accuracy 0.5163
fold 2, epoch 20, val loss 14.712533116340637 val accuracy 0.495
fold 2, epoch 40, val loss 14.83323323726654 val accuracy 0.495
fold 2, epoch 60, val loss 14.884495317935944 val accuracy 0.495
fold 2, epoch 80, val loss 15.051126956939697 val accuracy 0.495
fold 3, epoch 20, val loss 14.833477139472961 val accuracy 0.514
fold 3, epoch 40, val loss 14.826746106147766 val accuracy 0.5355
fold 3, epoch 60, val loss 14.820115685462952 val accuracy 0.5322
fold 3, epoch 80, val loss 14.999350011348724 val accuracy 0.535


[32m[I 2023-11-22 22:01:38,733][0m Trial 9 finished with value: 0.5140333333333333 and parameters: {'lr': 1.2041626787161341e-05, 'l2_lambda': 4.334798450225049e-09, 'batch_size': 256, 'epochs': 90, 'flanking_out_channel1': 512, 'flanking_out_channel2': 256, 'fc1_out': 256, 'concat_dropout_rate_fc1': 0.2, 'fc2_out': 4, 'concat_dropout_rate_fc2': 0}. Best is trial 8 with value: 0.5247666666666667.[0m


fold 1, epoch 20, val loss 37.38149803876877 val accuracy 0.5167
fold 2, epoch 20, val loss 34.9550096988678 val accuracy 0.5299
