In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, random_split
import pandas as pd
import numpy as np
import os
import random
from collections import defaultdict, Counter
from itertools import combinations
import json
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import KFold
import optuna
from torchmetrics.classification import F1Score
import pickle

import sys
### import Dataset preparation and model training classes from BS_LS_scripts folder
sys.path.insert(1, '/home/wangc90/circRNA/circRNA_Data/BS_LS_scripts/')
from BS_LS_DataSet_2 import BS_LS_DataSet_Prep, RCM_Score
from BS_LS_Training_Base_models_0 import Objective, Objective_CV


In [2]:
class RCM_optuna_flanking(nn.Module):
    '''
        1-D CNN + MLP classifier for the RCM score distribution of the flanking
        introns, with layer widths and dropout rates sampled from an Optuna trial.

        Architecture (sizes fixed by the layers built in __init__):
            Conv1d(5 -> C1, kernel 5, stride 5) -> BatchNorm1d -> ReLU
            Conv1d(C1 -> C2, kernel 5, stride 5) -> BatchNorm1d -> ReLU
            flatten -> Linear(C2 * 5 -> F1) -> BatchNorm1d -> ReLU -> Dropout
                    -> Linear(F1 -> F2)     -> BatchNorm1d -> ReLU -> Dropout
                    -> Linear(F2 -> 2)

        Input : float tensor of shape (batch, 5, L); L must shrink to length 5
                after the two stride-5 convolutions (e.g. L = 125).
                NOTE(review): the meaning of the 5 input channels is defined by
                the dataset code, not visible here - confirm against RCM_Score.
        Output: tensor of shape (batch, 2) holding unnormalised class scores
                (presumably fed to nn.CrossEntropyLoss by the training objective).
    '''
    def __init__(self, trial):
        '''
            trial: object exposing suggest_categorical(name, choices), i.e. an
                   Optuna trial. The suggest_* calls are issued in a fixed order
                   and with fixed names, which become the study's parameter keys.
        '''
        super().__init__()

        # conv layer 1: non-overlapping windows of 5 positions
        self.out_channel1 = trial.suggest_categorical('flanking_out_channel1', [128, 256, 512])
        self.conv1 = nn.Conv1d(in_channels=5, out_channels=self.out_channel1,
                               kernel_size=5, stride=5, padding=0)
        self.conv1_bn = nn.BatchNorm1d(self.out_channel1)

        # conv layer 2: again non-overlapping windows of 5
        self.out_channel2 = trial.suggest_categorical('flanking_out_channel2', [128, 256, 512])
        self.conv2 = nn.Conv1d(in_channels=self.out_channel1, out_channels=self.out_channel2,
                               kernel_size=5, stride=5, padding=0)
        self.conv2_bn = nn.BatchNorm1d(self.out_channel2)

        # sequence length expected after conv2; fixes the width of fc1's input
        self.conv2_out_dim = 5
        self.fc1_input_dim = self.conv2_out_dim * self.out_channel2

        # fc layer 1
        self.fc1_out = trial.suggest_categorical('fc1_out', [128, 256, 512])
        self.fc1 = nn.Linear(self.fc1_input_dim, self.fc1_out)
        self.fc1_bn = nn.BatchNorm1d(self.fc1_out)
        dropout_rate_fc1 = trial.suggest_categorical("concat_dropout_rate_fc1", [0, 0.1, 0.2, 0.4])
        self.drop_nn1 = nn.Dropout(p=dropout_rate_fc1)

        # fc layer 2
        self.fc2_out = trial.suggest_categorical('fc2_out', [4, 8, 16, 32])
        self.fc2 = nn.Linear(self.fc1_out, self.fc2_out)
        self.fc2_bn = nn.BatchNorm1d(self.fc2_out)
        dropout_rate_fc2 = trial.suggest_categorical("concat_dropout_rate_fc2", [0, 0.1, 0.2, 0.4])
        self.drop_nn2 = nn.Dropout(p=dropout_rate_fc2)

        # output layer: two class scores
        self.fc3 = nn.Linear(self.fc2_out, 2)

    def forward(self, x):
        '''
            x: tensor of shape (batch, 5, L) - see the class docstring.
            Returns a (batch, 2) tensor.
            Raises RuntimeError if L does not reduce to conv2_out_dim positions.
        '''
        out = torch.relu(self.conv1_bn(self.conv1(x)))
        out = torch.relu(self.conv2_bn(self.conv2(out)))

        # Flatten each sample on its own. A view(-1, C2 * 5) would quietly fold
        # any extra sequence length into the batch dimension when the input is
        # mis-sized; keeping the batch axis makes fc1 reject such inputs.
        out = out.flatten(start_dim=1)

        out = self.drop_nn1(torch.relu(self.fc1_bn(self.fc1(out))))
        out = self.drop_nn2(torch.relu(self.fc2_bn(self.fc2(out))))
        return self.fc3(out)


In [3]:
def rcm_flankingOnly_large_windows_optuna(num_trial):
    '''
        Run an Optuna hyper-parameter search for RCM_optuna_flanking on the
        flanking-intron RCM scores and write a summary of the study to disk.

        Steps:
            1. Build the BS/LS dataset (flanking_junction_bps=100,
               flanking_intron_bps=5000, training_size=9000) and take the first
               set of keys returned by get_train_test_keys().
            2. Convert those keys to tensors and keep only the flanking RCM
               tensor and the labels, wrapped in an RCM_Score dataset.
            3. Maximise the value returned by Objective_CV(cv=3, ...) over
               `num_trial` trials with a median pruner.
               NOTE(review): Objective_CV is imported from
               BS_LS_Training_Base_models_0; it presumably returns the mean
               validation accuracy over the 3 folds - confirm there.
            4. Append a text summary to <optuna_folder>/optuna.txt and write the
               per-trial table to <optuna_folder>/optuna.csv (tab separated).

        num_trial: number of Optuna trials to run (passed as n_trials).
        Returns None; prints the dataset size as a side effect.
    '''
    model_output_root = '/home/wangc90/circRNA/circRNA_Data/model_outputs/rcm_flankingOnly_large_windows/9000'
    ### folder handed to Objective_CV for the 3-fold CV validation accuracies
    val_acc_folder = model_output_root + '/val_acc_cv3'
    ### where to save the detailed optuna results
    optuna_folder = model_output_root + '/optuna'

    # Create the output folders before the (very long) search so that a missing
    # directory cannot make the run fail only when results are being written.
    # NOTE(review): val_acc_folder is assumed to be written by Objective_CV.
    os.makedirs(val_acc_folder, exist_ok=True)
    os.makedirs(optuna_folder, exist_ok=True)

    BS_LS_coordinates_path = '/home/wangc90/circRNA/circRNA_Data/BS_LS_data/updated_data/BS_LS_coordinates_final.csv'
    hg19_seq_dict_json_path = '/home/wangc90/circRNA/circRNA_Data/hg19_seq/hg19_seq_dict.json'
    flanking_dict_folder = '/home/wangc90/circRNA/circRNA_Data/BS_LS_data/flanking_dicts/'
    rcm_scores_folder = '/home/wangc90/circRNA/circRNA_Data/BS_LS_data/flanking_dicts/rcm_scores/'

    bs_ls_dataset = BS_LS_DataSet_Prep(BS_LS_coordinates_path=BS_LS_coordinates_path,
                                       hg19_seq_dict_json_path=hg19_seq_dict_json_path,
                                       flanking_dict_folder=flanking_dict_folder,
                                       flanking_junction_bps=100,
                                       flanking_intron_bps=5000,
                                       training_size=9000)

    ## generate the junction and flanking intron dict
    bs_ls_dataset.get_junction_flanking_intron_seq()

    train_key_1, _, test_keys = bs_ls_dataset.get_train_test_keys()

    ### only use flanking_rcm_scores (3rd returned value) and the labels (6th)
    _, _, train_torch_flanking_rcm, _, _, train_torch_labels = bs_ls_dataset.seq_to_tensor(
        data_keys=train_key_1, rcm_folder=rcm_scores_folder,
        is_rcm=True, is_upper_lower_concat=False)

    RCM_kmer_Score_dataset = RCM_Score(flanking_only=True,
                                       flanking_rcm=train_torch_flanking_rcm,
                                       upper_rcm=None,
                                       lower_rcm=None, label=train_torch_labels)

    print(len(RCM_kmer_Score_dataset))

    # Single study: the median pruner stays inactive for the first 10 trials.
    study = optuna.create_study(pruner=optuna.pruners.MedianPruner(n_warmup_steps=1, n_startup_trials=10),
                                direction='maximize')

    study.optimize(Objective_CV(cv=3, model=RCM_optuna_flanking,
                                dataset=RCM_kmer_Score_dataset,
                                val_acc_folder=val_acc_folder), n_trials=num_trial, gc_after_trial=True)

    pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
    complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

    # Append mode: repeated runs accumulate their summaries in the same file.
    with open(os.path.join(optuna_folder, 'optuna.txt'), 'a') as f:
        f.write("Study statistics: \n")
        f.write(f"Number of finished trials: {len(study.trials)}\n")
        f.write(f"Number of pruned trials: {len(pruned_trials)}\n")
        f.write(f"Number of complete trials: {len(complete_trials)}\n")

        f.write("Best trial:\n")
        if complete_trials:
            trial = study.best_trial
            f.write(f"Value: {trial.value}\n")
            f.write("Params:\n")
            for key, value in trial.params.items():
                f.write(f"{key}:{value}\n")
        else:
            # study.best_trial raises when no trial has completed
            f.write("No completed trials\n")

    # Keep only the value / parameter columns of the per-trial table.
    df = study.trials_dataframe().drop(
        columns=['state', 'datetime_start', 'datetime_complete', 'duration', 'number'],
        errors='ignore')
    df.to_csv(os.path.join(optuna_folder, 'optuna.csv'), sep='\t', index=False)

In [None]:
# Launch the search: num_trial is forwarded to study.optimize(n_trials=...), so this runs 500 Optuna trials.
# NOTE(review): the captured log below shows single trials taking roughly 7-36 minutes, so a full run is very long.
rcm_flankingOnly_large_windows_optuna(num_trial=500)

[32m[I 2023-12-13 16:42:05,594][0m A new study created in memory with name: no-name-a60beb67-bbf0-45cb-b4f4-dc2f6222b9eb[0m


chr5|138837130|138837392|- has N in the extracted junctions, belongs to BS
There are 0 overlapped flanking sequence from BS and LS  
There are 7 repeated BS sequences
There are 2 repeated LS sequences


[32m[I 2023-12-13 16:44:04,633][0m A new study created in memory with name: no-name-9eff2982-c72a-4cc7-baea-e808e19b4213[0m


18000
fold 1, epoch 20, val loss 25.325473368167877 val accuracy 0.5195
fold 1, epoch 40, val loss 44.86837124824524 val accuracy 0.5277
fold 1, epoch 60, val loss 55.36450743675232 val accuracy 0.5223
fold 1, epoch 80, val loss 61.93777871131897 val accuracy 0.5253
fold 1, epoch 100, val loss 70.18387222290039 val accuracy 0.5248
fold 1, epoch 120, val loss 72.23770713806152 val accuracy 0.5243
fold 2, epoch 20, val loss 26.728899478912354 val accuracy 0.5225
fold 2, epoch 40, val loss 46.56070411205292 val accuracy 0.5212
fold 2, epoch 60, val loss 55.3813613653183 val accuracy 0.5258
fold 2, epoch 80, val loss 62.5520384311676 val accuracy 0.5212
fold 2, epoch 100, val loss 65.24964499473572 val accuracy 0.5258
fold 2, epoch 120, val loss 66.81811666488647 val accuracy 0.5245
fold 3, epoch 20, val loss 25.810437202453613 val accuracy 0.5107
fold 3, epoch 40, val loss 42.63780736923218 val accuracy 0.5217
fold 3, epoch 60, val loss 55.84415149688721 val accuracy 0.5295
fold 3, epoch 

[32m[I 2023-12-13 17:05:59,947][0m Trial 0 finished with value: 0.5251666666666667 and parameters: {'lr': 0.00013885148516962036, 'l2_lambda': 9.942895565068018e-07, 'batch_size': 256, 'epochs': 120, 'flanking_out_channel1': 512, 'flanking_out_channel2': 128, 'fc1_out': 256, 'concat_dropout_rate_fc1': 0.2, 'fc2_out': 32, 'concat_dropout_rate_fc2': 0.4}. Best is trial 0 with value: 0.5251666666666667.[0m


fold 3, epoch 120, val loss 70.23675560951233 val accuracy 0.5267
fold 1, epoch 20, val loss 73.17754864692688 val accuracy 0.5325
fold 2, epoch 20, val loss 73.04633712768555 val accuracy 0.5363
fold 3, epoch 20, val loss 74.28981858491898 val accuracy 0.5273


[32m[I 2023-12-13 17:13:49,015][0m Trial 1 finished with value: 0.5313333333333333 and parameters: {'lr': 1.6902387996402e-05, 'l2_lambda': 3.0991424910059367e-09, 'batch_size': 64, 'epochs': 30, 'flanking_out_channel1': 128, 'flanking_out_channel2': 512, 'fc1_out': 512, 'concat_dropout_rate_fc1': 0.1, 'fc2_out': 32, 'concat_dropout_rate_fc2': 0.2}. Best is trial 1 with value: 0.5313333333333333.[0m


fold 1, epoch 20, val loss 148.46243238449097 val accuracy 0.5218
fold 1, epoch 40, val loss 201.94193756580353 val accuracy 0.5183
fold 1, epoch 60, val loss 238.27638924121857 val accuracy 0.5208
fold 1, epoch 80, val loss 249.67528760433197 val accuracy 0.5173
fold 2, epoch 20, val loss 146.7046965956688 val accuracy 0.518
fold 2, epoch 40, val loss 193.6542900800705 val accuracy 0.529
fold 2, epoch 60, val loss 223.5519688129425 val accuracy 0.529
fold 2, epoch 80, val loss 235.89124810695648 val accuracy 0.5298
fold 3, epoch 20, val loss 136.86252123117447 val accuracy 0.5295
fold 3, epoch 40, val loss 186.5613133907318 val accuracy 0.5257
fold 3, epoch 60, val loss 219.98276329040527 val accuracy 0.5315
fold 3, epoch 80, val loss 241.83380925655365 val accuracy 0.5195


[32m[I 2023-12-13 17:39:41,200][0m Trial 2 finished with value: 0.5291 and parameters: {'lr': 8.013303007297937e-05, 'l2_lambda': 2.271194923290048e-08, 'batch_size': 64, 'epochs': 90, 'flanking_out_channel1': 128, 'flanking_out_channel2': 512, 'fc1_out': 512, 'concat_dropout_rate_fc1': 0.2, 'fc2_out': 32, 'concat_dropout_rate_fc2': 0.2}. Best is trial 1 with value: 0.5313333333333333.[0m


fold 1, epoch 20, val loss 286.5271054506302 val accuracy 0.5205
fold 1, epoch 40, val loss 413.2575385570526 val accuracy 0.5232
fold 1, epoch 60, val loss 443.5320693850517 val accuracy 0.5188
fold 1, epoch 80, val loss 443.90657448768616 val accuracy 0.5203
fold 2, epoch 20, val loss 299.06073665618896 val accuracy 0.5245
fold 2, epoch 40, val loss 406.36803072690964 val accuracy 0.5222
fold 2, epoch 60, val loss 440.4018065929413 val accuracy 0.5155
fold 2, epoch 80, val loss 467.8294583559036 val accuracy 0.5162
fold 3, epoch 20, val loss 297.39609003067017 val accuracy 0.5235
fold 3, epoch 40, val loss 394.76381665468216 val accuracy 0.5208
fold 3, epoch 60, val loss 421.4132002592087 val accuracy 0.5212
fold 3, epoch 80, val loss 458.8377900123596 val accuracy 0.5113


[32m[I 2023-12-13 18:15:37,253][0m Trial 3 finished with value: 0.5173333333333333 and parameters: {'lr': 0.00044340164090517063, 'l2_lambda': 2.861216452254345e-08, 'batch_size': 32, 'epochs': 90, 'flanking_out_channel1': 256, 'flanking_out_channel2': 128, 'fc1_out': 256, 'concat_dropout_rate_fc1': 0.2, 'fc2_out': 32, 'concat_dropout_rate_fc2': 0.2}. Best is trial 1 with value: 0.5313333333333333.[0m


fold 1, epoch 20, val loss 35.458916902542114 val accuracy 0.5108
fold 2, epoch 20, val loss 34.00640708208084 val accuracy 0.5085
fold 3, epoch 20, val loss 34.08370780944824 val accuracy 0.5252


[32m[I 2023-12-13 18:22:45,984][0m Trial 4 finished with value: 0.5143 and parameters: {'lr': 1.6496695246446215e-05, 'l2_lambda': 5.039979556252447e-08, 'batch_size': 128, 'epochs': 30, 'flanking_out_channel1': 256, 'flanking_out_channel2': 128, 'fc1_out': 256, 'concat_dropout_rate_fc1': 0, 'fc2_out': 4, 'concat_dropout_rate_fc2': 0}. Best is trial 1 with value: 0.5313333333333333.[0m


fold 1, epoch 20, val loss 69.6079820394516 val accuracy 0.5307
fold 1, epoch 40, val loss 77.88380181789398 val accuracy 0.5277
fold 1, epoch 60, val loss 90.9084467291832 val accuracy 0.5223
fold 1, epoch 80, val loss 110.01918506622314 val accuracy 0.5193
fold 1, epoch 100, val loss 129.823666036129 val accuracy 0.5143
fold 1, epoch 120, val loss 149.65672492980957 val accuracy 0.5155
fold 2, epoch 20, val loss 67.26979798078537 val accuracy 0.532
fold 2, epoch 40, val loss 76.17858177423477 val accuracy 0.5213
fold 2, epoch 60, val loss 92.50261157751083 val accuracy 0.5165
fold 2, epoch 80, val loss 111.51938211917877 val accuracy 0.5212
fold 2, epoch 100, val loss 134.03442841768265 val accuracy 0.5173
fold 2, epoch 120, val loss 157.83999872207642 val accuracy 0.518
fold 3, epoch 20, val loss 66.83822113275528 val accuracy 0.5263
fold 3, epoch 40, val loss 73.96940529346466 val accuracy 0.524
fold 3, epoch 60, val loss 89.623763859272 val accuracy 0.5235
fold 3, epoch 80, val lo

[32m[I 2023-12-13 18:55:30,111][0m Trial 5 finished with value: 0.5181 and parameters: {'lr': 1.2168393039014852e-05, 'l2_lambda': 1.4399810201150328e-07, 'batch_size': 64, 'epochs': 120, 'flanking_out_channel1': 256, 'flanking_out_channel2': 256, 'fc1_out': 128, 'concat_dropout_rate_fc1': 0, 'fc2_out': 16, 'concat_dropout_rate_fc2': 0.2}. Best is trial 1 with value: 0.5313333333333333.[0m


fold 3, epoch 120, val loss 149.57324892282486 val accuracy 0.5208
fold 1, epoch 20, val loss 9.795427739620209 val accuracy 0.5288
fold 1, epoch 40, val loss 12.127801060676575 val accuracy 0.5247
fold 1, epoch 60, val loss 14.13537609577179 val accuracy 0.5287
fold 1, epoch 80, val loss 16.041125774383545 val accuracy 0.5233
fold 1, epoch 100, val loss 17.119811415672302 val accuracy 0.5247
fold 1, epoch 120, val loss 18.69002377986908 val accuracy 0.5278
fold 2, epoch 20, val loss 9.843673825263977 val accuracy 0.5373
fold 2, epoch 40, val loss 12.60049593448639 val accuracy 0.5342
fold 2, epoch 60, val loss 14.357364535331726 val accuracy 0.5302
fold 2, epoch 80, val loss 16.022205591201782 val accuracy 0.533
fold 2, epoch 100, val loss 17.456456184387207 val accuracy 0.5312
fold 2, epoch 120, val loss 19.096523642539978 val accuracy 0.5352
fold 3, epoch 20, val loss 10.021587193012238 val accuracy 0.5202
fold 3, epoch 40, val loss 12.345686912536621 val accuracy 0.5208
fold 3, epo

[32m[I 2023-12-13 19:15:05,767][0m Trial 6 finished with value: 0.5287333333333334 and parameters: {'lr': 0.00010665256914566466, 'l2_lambda': 1.0918703019401434e-09, 'batch_size': 512, 'epochs': 120, 'flanking_out_channel1': 128, 'flanking_out_channel2': 256, 'fc1_out': 512, 'concat_dropout_rate_fc1': 0.1, 'fc2_out': 16, 'concat_dropout_rate_fc2': 0.2}. Best is trial 1 with value: 0.5313333333333333.[0m


fold 3, epoch 120, val loss 19.64191722869873 val accuracy 0.5232
fold 1, epoch 20, val loss 72.61453652381897 val accuracy 0.516
fold 1, epoch 40, val loss 100.25596857070923 val accuracy 0.5225
fold 1, epoch 60, val loss 116.69084453582764 val accuracy 0.5323
fold 1, epoch 80, val loss 126.49225008487701 val accuracy 0.5285
fold 1, epoch 100, val loss 129.90722715854645 val accuracy 0.5275
fold 1, epoch 120, val loss 137.05706071853638 val accuracy 0.5287
fold 2, epoch 20, val loss 74.15760958194733 val accuracy 0.5282
