In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, random_split
import pandas as pd
import numpy as np
import os
import random
from collections import defaultdict, Counter
from itertools import combinations
import json
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import KFold
import optuna
from torchmetrics.classification import F1Score
import pickle

import sys
### import the dataset-preparation and model-training classes from the BS_LS_scripts folder
sys.path.insert(1, '/home/wangc90/circRNA/circRNA_Data/BS_LS_scripts/')
from BS_LS_DataSet_3 import BS_LS_DataSet_Prep, RCM_Score
from BS_LS_Training_Base_models_1 import Objective, Objective_CV


In [2]:
class RCM_optuna_flanking(nn.Module):
    '''
        1-D CNN classifier over the RCM score distribution of the flanking introns.

        Layout (channel widths, hidden sizes and dropout rates are sampled from
        the Optuna ``trial``):
            Conv1d(5 -> C1, kernel 5, stride 5) -> BatchNorm -> ReLU
            Conv1d(C1 -> C2, kernel 5, stride 5) -> BatchNorm -> ReLU
            flatten -> Linear(-> fc1_out) -> BatchNorm -> ReLU -> Dropout
                    -> Linear(-> fc2_out) -> BatchNorm -> ReLU -> Dropout
                    -> Linear(-> 2)   (raw logits, intended for nn.CrossEntropyLoss)

        Args:
            trial: Optuna trial used to sample the hyper-parameters. The parameter
                names ('flanking_out_channel1', 'concat_dropout_rate_fc1', ...) are
                kept unchanged so results stay comparable with earlier studies.
            input_length: length of the last (sequence) axis of the input. The
                default of 250 reproduces the previously hard-coded length of 10
                positions left after the second convolution.

        Input:  tensor of shape (batch, 5, input_length).
        Output: tensor of shape (batch, 2).
    '''

    # Both convolutions use non-overlapping windows (kernel_size == stride), no padding.
    _KERNEL_SIZE = 5
    _STRIDE = 5

    @classmethod
    def _conv_out_len(cls, length):
        '''Output length of one Conv1d layer of this model for an input of `length`.'''
        return (length - cls._KERNEL_SIZE) // cls._STRIDE + 1

    def __init__(self, trial, input_length=250):

        super(RCM_optuna_flanking, self).__init__()

        # conv layer 1: 5 input channels -> out_channel1
        self.out_channel1 = trial.suggest_categorical('flanking_out_channel1', [128, 256, 512])
        self.conv1 = nn.Conv1d(in_channels=5, out_channels=self.out_channel1,
                               kernel_size=self._KERNEL_SIZE, stride=self._STRIDE, padding=0)
        self.conv1_bn = nn.BatchNorm1d(self.out_channel1)

        # conv layer 2: out_channel1 -> out_channel2
        self.out_channel2 = trial.suggest_categorical('flanking_out_channel2', [128, 256, 512])
        self.conv2 = nn.Conv1d(in_channels=self.out_channel1, out_channels=self.out_channel2,
                               kernel_size=self._KERNEL_SIZE, stride=self._STRIDE, padding=0)
        self.conv2_bn = nn.BatchNorm1d(self.out_channel2)

        # Sequence length left after both convolutions (10 for input_length=250). It is
        # derived rather than hard-coded, so fc1 always matches the actual conv output.
        self.conv2_out_dim = self._conv_out_len(self._conv_out_len(input_length))
        if self.conv2_out_dim < 1:
            raise ValueError(
                f'input_length={input_length} is too short for two Conv1d layers '
                f'with kernel_size={self._KERNEL_SIZE} and stride={self._STRIDE}')

        self.fc1_input_dim = self.conv2_out_dim * self.out_channel2

        # fc layer 1
        self.fc1_out = trial.suggest_categorical('fc1_out', [128, 256, 512])
        self.fc1 = nn.Linear(self.fc1_input_dim, self.fc1_out)
        self.fc1_bn = nn.BatchNorm1d(self.fc1_out)

        dropout_rate_fc1 = trial.suggest_categorical("concat_dropout_rate_fc1", [0, 0.1, 0.2, 0.4])
        self.drop_nn1 = nn.Dropout(p=dropout_rate_fc1)

        # fc layer 2
        self.fc2_out = trial.suggest_categorical('fc2_out', [4, 8, 16, 32])
        self.fc2 = nn.Linear(self.fc1_out, self.fc2_out)
        self.fc2_bn = nn.BatchNorm1d(self.fc2_out)

        dropout_rate_fc2 = trial.suggest_categorical("concat_dropout_rate_fc2", [0, 0.1, 0.2, 0.4])
        self.drop_nn2 = nn.Dropout(p=dropout_rate_fc2)

        # Output layer: 2 raw logits, to be used with nn.CrossEntropyLoss().
        self.fc3 = nn.Linear(self.fc2_out, 2)

    def forward(self, x):
        '''Return 2-class logits of shape (batch, 2) for x of shape (batch, 5, input_length).'''
        out = torch.relu(self.conv1_bn(self.conv1(x)))
        out = torch.relu(self.conv2_bn(self.conv2(out)))

        # Flatten per sample and keep the batch axis explicit. A wrong-length input then
        # fails in fc1 with a shape error. The old view(-1, ...) could silently regroup
        # values across samples into a different batch size.
        out = out.flatten(1)

        out = self.drop_nn1(torch.relu(self.fc1_bn(self.fc1(out))))
        out = self.drop_nn2(torch.relu(self.fc2_bn(self.fc2(out))))
        return self.fc3(out)


In [3]:
def _build_flanking_rcm_dataset(data_root, training_size):
    '''Build the RCM_Score dataset (flanking-intron RCM scores + labels) for the training keys.'''
    bs_ls_dataset = BS_LS_DataSet_Prep(
        BS_LS_coordinates_path=f'{data_root}/BS_LS_data/updated_data/BS_LS_coordinates_final.csv',
        hg19_seq_dict_json_path=f'{data_root}/hg19_seq/hg19_seq_dict.json',
        flanking_dict_folder=f'{data_root}/BS_LS_data/flanking_dicts/',
        flanking_junction_bps=100,
        flanking_intron_bps=5000,
        training_size=training_size)

    # Generate the junction and flanking-intron sequence dicts.
    bs_ls_dataset.get_junction_flanking_intron_seq()

    train_keys, _, _ = bs_ls_dataset.get_train_test_keys()

    # seq_to_tensor returns six values. Only the flanking RCM scores (3rd) and the
    # labels (6th) are used by this flanking-only model.
    _, _, train_flanking_rcm, _, _, train_labels = bs_ls_dataset.seq_to_tensor(
        data_keys=train_keys,
        rcm_folder=f'{data_root}/BS_LS_data/flanking_dicts/rcm_scores/',
        is_rcm=True, is_upper_lower_concat=False)

    return RCM_Score(flanking_only=True,
                     flanking_rcm=train_flanking_rcm,
                     upper_rcm=None,
                     lower_rcm=None,
                     label=train_labels)


def _write_study_report(study, optuna_folder):
    '''Append a text summary of `study` to optuna.txt and write all trials to optuna.csv.'''
    os.makedirs(optuna_folder, exist_ok=True)

    pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
    complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

    with open(os.path.join(optuna_folder, 'optuna.txt'), 'a') as f:
        f.write("Study statistics: \n")
        f.write(f"Number of finished trials: {len(study.trials)}\n")
        f.write(f"Number of pruned trials: {len(pruned_trials)}\n")
        f.write(f"Number of complete trials: {len(complete_trials)}\n")

        # study.best_trial raises ValueError when no trial completed (e.g. all pruned
        # or failed). Record that case instead of raising and skipping the CSV export.
        if complete_trials:
            best = study.best_trial
            f.write("Best trial:\n")
            f.write(f"Value: {best.value}\n")
            f.write("Params:\n")
            for key, value in best.params.items():
                f.write(f"{key}:{value}\n")
        else:
            f.write("Best trial: none (no trial completed)\n")

    # Drop bookkeeping columns. errors='ignore' keeps this working when a column is absent
    # (e.g. an empty study).
    df = study.trials_dataframe().drop(
        ['state', 'datetime_start', 'datetime_complete', 'duration', 'number'],
        axis=1, errors='ignore')
    df.to_csv(os.path.join(optuna_folder, 'optuna.csv'), sep='\t', index=False)


def rcm_flankingOnly_all_windows_optuna(num_trial, training_size=8000, cv=3,
                                        data_root='/home/wangc90/circRNA/circRNA_Data'):
    '''
    Run an Optuna search for RCM_optuna_flanking on the flanking-intron RCM scores.

    The training keys come from BS_LS_DataSet_Prep.get_train_test_keys(). Only the
    flanking RCM scores and the labels are used. Each trial is evaluated by
    Objective_CV(cv=cv, ...), and the study maximizes that value under a MedianPruner.

    Args:
        num_trial: number of Optuna trials to run.
        training_size: forwarded to BS_LS_DataSet_Prep and used in the output path
            (default 8000, the value previously hard-coded in both places).
        cv: number of folds passed to Objective_CV. It also appears in the name of
            the validation-accuracy folder ("val_acc_cv3" for the default).
        data_root: folder containing BS_LS_data/, hg19_seq/ and model_outputs/.

    Side effects:
        Appends a summary to <out>/optuna/optuna.txt and writes the per-trial table
        to <out>/optuna/optuna.csv (tab-separated), where
        <out> = <data_root>/model_outputs/rcm_flankingOnly_all_windows/<training_size>.
    '''
    output_root = f'{data_root}/model_outputs/rcm_flankingOnly_all_windows/{training_size}'
    # Passed to Objective_CV as the place for the per-fold validation accuracy.
    val_acc_folder = f'{output_root}/val_acc_cv{cv}'
    # Where the study summary and the detailed trial table are written.
    optuna_folder = f'{output_root}/optuna'

    dataset = _build_flanking_rcm_dataset(data_root, training_size)
    print(len(dataset))

    study = optuna.create_study(
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=1, n_startup_trials=10),
        direction='maximize')

    study.optimize(Objective_CV(cv=cv, model=RCM_optuna_flanking,
                                dataset=dataset,
                                val_acc_folder=val_acc_folder),
                   n_trials=num_trial, gc_after_trial=True)

    _write_study_report(study, optuna_folder)

In [None]:
# Launch the full hyper-parameter search: 500 Optuna trials, each scored with 3-fold CV.
rcm_flankingOnly_all_windows_optuna(500)

[32m[I 2023-11-22 19:18:00,230][0m A new study created in memory with name: no-name-f1275dcc-d170-4872-9c99-c34af0eb342d[0m


chr5|138837130|138837392|- has N in the extracted junctions, belongs to BS
There are 0 overlapped flanking sequence from BS and LS  
There are 7 repeated BS sequences
There are 2 repeated LS sequences


[32m[I 2023-11-22 19:20:17,805][0m A new study created in memory with name: no-name-726d3ca2-b2fa-4a93-bec3-7316e065429c[0m


16000
fold 1, epoch 20, val loss 15.314567744731903 val accuracy 0.527
fold 1, epoch 40, val loss 17.06819522380829 val accuracy 0.5238
fold 1, epoch 60, val loss 19.26196539402008 val accuracy 0.5157
fold 2, epoch 20, val loss 14.981064796447754 val accuracy 0.543
fold 2, epoch 40, val loss 16.4035427570343 val accuracy 0.5357
fold 2, epoch 60, val loss 18.491921305656433 val accuracy 0.5355
fold 3, epoch 20, val loss 15.52218097448349 val accuracy 0.525
fold 3, epoch 40, val loss 17.166387975215912 val accuracy 0.5211


[32m[I 2023-11-22 19:32:49,119][0m Trial 0 finished with value: 0.5241000000000001 and parameters: {'lr': 1.9265704241842948e-05, 'l2_lambda': 8.955787810162911e-09, 'batch_size': 256, 'epochs': 60, 'flanking_out_channel1': 128, 'flanking_out_channel2': 256, 'fc1_out': 256, 'concat_dropout_rate_fc1': 0, 'fc2_out': 32, 'concat_dropout_rate_fc2': 0.1}. Best is trial 0 with value: 0.5241000000000001.[0m


fold 3, epoch 60, val loss 19.07050132751465 val accuracy 0.5211
fold 1, epoch 20, val loss 30.48302513360977 val accuracy 0.5315
fold 1, epoch 40, val loss 36.47945749759674 val accuracy 0.5289
fold 1, epoch 60, val loss 44.36103230714798 val accuracy 0.5247
fold 1, epoch 80, val loss 52.48164635896683 val accuracy 0.5253
fold 1, epoch 100, val loss 63.488852858543396 val accuracy 0.5366
fold 1, epoch 120, val loss 71.47134602069855 val accuracy 0.5315
fold 2, epoch 20, val loss 31.023080110549927 val accuracy 0.5273
fold 2, epoch 40, val loss 37.12838101387024 val accuracy 0.5205
fold 2, epoch 60, val loss 46.69512206315994 val accuracy 0.5245
fold 2, epoch 80, val loss 55.65063142776489 val accuracy 0.5185
fold 2, epoch 100, val loss 65.82132506370544 val accuracy 0.5226
fold 2, epoch 120, val loss 75.85387432575226 val accuracy 0.5194
fold 3, epoch 20, val loss 30.560830950737 val accuracy 0.5314
fold 3, epoch 40, val loss 36.29194217920303 val accuracy 0.5365
fold 3, epoch 60, val

[32m[I 2023-11-22 20:02:56,309][0m Trial 1 finished with value: 0.5244333333333333 and parameters: {'lr': 3.162629108526024e-05, 'l2_lambda': 6.237840296069283e-09, 'batch_size': 128, 'epochs': 120, 'flanking_out_channel1': 256, 'flanking_out_channel2': 128, 'fc1_out': 256, 'concat_dropout_rate_fc1': 0.1, 'fc2_out': 16, 'concat_dropout_rate_fc2': 0.4}. Best is trial 1 with value: 0.5244333333333333.[0m


fold 3, epoch 120, val loss 77.5265040397644 val accuracy 0.5224
fold 1, epoch 20, val loss 29.173382997512817 val accuracy 0.5366
fold 1, epoch 40, val loss 29.428374886512756 val accuracy 0.5313
fold 1, epoch 60, val loss 29.851643085479736 val accuracy 0.533
fold 1, epoch 80, val loss 30.24014574289322 val accuracy 0.5298
fold 2, epoch 20, val loss 28.786368429660797 val accuracy 0.5539
fold 2, epoch 40, val loss 29.153178572654724 val accuracy 0.5528
fold 2, epoch 60, val loss 29.457596361637115 val accuracy 0.5404
fold 2, epoch 80, val loss 30.04780799150467 val accuracy 0.531
fold 3, epoch 20, val loss 29.531602025032043 val accuracy 0.5436
fold 3, epoch 40, val loss 30.28730034828186 val accuracy 0.5419
fold 3, epoch 60, val loss 31.07867604494095 val accuracy 0.5301
fold 3, epoch 80, val loss 31.990120887756348 val accuracy 0.5383


[32m[I 2023-11-22 20:26:48,184][0m Trial 2 finished with value: 0.5308 and parameters: {'lr': 1.0305496778192561e-05, 'l2_lambda': 5.4403072827698657e-08, 'batch_size': 128, 'epochs': 90, 'flanking_out_channel1': 512, 'flanking_out_channel2': 256, 'fc1_out': 256, 'concat_dropout_rate_fc1': 0.4, 'fc2_out': 4, 'concat_dropout_rate_fc2': 0}. Best is trial 2 with value: 0.5308.[0m


fold 1, epoch 20, val loss 7.573329985141754 val accuracy 0.536
fold 1, epoch 40, val loss 7.534786939620972 val accuracy 0.5435
fold 1, epoch 60, val loss 7.557947039604187 val accuracy 0.5486
fold 1, epoch 80, val loss 7.5750961899757385 val accuracy 0.5487
fold 2, epoch 20, val loss 7.565214157104492 val accuracy 0.5498
fold 2, epoch 40, val loss 7.5745890736579895 val accuracy 0.5438
fold 2, epoch 60, val loss 7.604867041110992 val accuracy 0.5451
fold 2, epoch 80, val loss 7.650249600410461 val accuracy 0.5423
fold 3, epoch 20, val loss 7.72695779800415 val accuracy 0.5367
fold 3, epoch 40, val loss 7.700622498989105 val accuracy 0.5447
fold 3, epoch 60, val loss 7.71911346912384 val accuracy 0.5462
fold 3, epoch 80, val loss 7.755358874797821 val accuracy 0.5447


[32m[I 2023-11-22 20:45:49,034][0m Trial 3 finished with value: 0.5428 and parameters: {'lr': 1.0858091481537958e-05, 'l2_lambda': 1.6153619299670391e-07, 'batch_size': 512, 'epochs': 90, 'flanking_out_channel1': 512, 'flanking_out_channel2': 128, 'fc1_out': 128, 'concat_dropout_rate_fc1': 0.4, 'fc2_out': 8, 'concat_dropout_rate_fc2': 0}. Best is trial 3 with value: 0.5428.[0m


fold 1, epoch 20, val loss 22.287332773208618 val accuracy 0.5184
fold 1, epoch 40, val loss 24.59628665447235 val accuracy 0.5304
fold 1, epoch 60, val loss 28.270833015441895 val accuracy 0.5328
fold 1, epoch 80, val loss 29.754035472869873 val accuracy 0.5298
fold 1, epoch 100, val loss 29.76432704925537 val accuracy 0.5304
fold 1, epoch 120, val loss 33.91359210014343 val accuracy 0.5285
fold 2, epoch 20, val loss 15.925340294837952 val accuracy 0.5232
fold 2, epoch 40, val loss 22.992340326309204 val accuracy 0.5262
fold 2, epoch 60, val loss 29.20237135887146 val accuracy 0.5265
fold 2, epoch 80, val loss 31.133686065673828 val accuracy 0.5262
fold 2, epoch 100, val loss 36.09417510032654 val accuracy 0.5198
fold 2, epoch 120, val loss 34.86344861984253 val accuracy 0.5262
fold 3, epoch 20, val loss 14.474853754043579 val accuracy 0.5245
fold 3, epoch 40, val loss 25.098841905593872 val accuracy 0.5286
fold 3, epoch 60, val loss 27.54655909538269 val accuracy 0.5173
fold 3, epoch

[32m[I 2023-11-22 21:11:16,257][0m Trial 4 finished with value: 0.5254 and parameters: {'lr': 0.00078712419797665, 'l2_lambda': 5.0294097562421123e-08, 'batch_size': 512, 'epochs': 120, 'flanking_out_channel1': 512, 'flanking_out_channel2': 128, 'fc1_out': 512, 'concat_dropout_rate_fc1': 0.4, 'fc2_out': 16, 'concat_dropout_rate_fc2': 0.2}. Best is trial 3 with value: 0.5428.[0m


fold 3, epoch 120, val loss 34.08187294006348 val accuracy 0.5215
fold 1, epoch 20, val loss 261.75162267684937 val accuracy 0.5201
fold 2, epoch 20, val loss 262.7395570278168 val accuracy 0.5254
fold 3, epoch 20, val loss 262.09327375888824 val accuracy 0.5286


[32m[I 2023-11-22 21:25:25,565][0m Trial 5 finished with value: 0.5277999999999999 and parameters: {'lr': 9.247506560787085e-05, 'l2_lambda': 5.3486120091749385e-09, 'batch_size': 32, 'epochs': 30, 'flanking_out_channel1': 128, 'flanking_out_channel2': 256, 'fc1_out': 512, 'concat_dropout_rate_fc1': 0.2, 'fc2_out': 32, 'concat_dropout_rate_fc2': 0}. Best is trial 3 with value: 0.5428.[0m


fold 1, epoch 20, val loss 14.757123172283173 val accuracy 0.5418
fold 1, epoch 40, val loss 15.010938823223114 val accuracy 0.542
fold 1, epoch 60, val loss 15.479822993278503 val accuracy 0.5356
fold 1, epoch 80, val loss 15.705274164676666 val accuracy 0.5317
fold 1, epoch 100, val loss 16.09922903776169 val accuracy 0.5319
fold 1, epoch 120, val loss 16.342253804206848 val accuracy 0.5302
fold 2, epoch 20, val loss 15.50257670879364 val accuracy 0.5292
fold 2, epoch 40, val loss 15.608249247074127 val accuracy 0.5395
fold 2, epoch 60, val loss 16.194899082183838 val accuracy 0.5374
fold 2, epoch 80, val loss 15.963672757148743 val accuracy 0.5389
fold 2, epoch 100, val loss 16.605023205280304 val accuracy 0.535
fold 2, epoch 120, val loss 16.801768720149994 val accuracy 0.5335
fold 3, epoch 20, val loss 14.637836635112762 val accuracy 0.5365
fold 3, epoch 40, val loss 14.758317291736603 val accuracy 0.5338
fold 3, epoch 60, val loss 14.94555813074112 val accuracy 0.5344
fold 3, epo

[32m[I 2023-11-22 21:52:41,521][0m Trial 6 finished with value: 0.5324666666666666 and parameters: {'lr': 1.2953565908545368e-05, 'l2_lambda': 7.984859216305832e-07, 'batch_size': 256, 'epochs': 120, 'flanking_out_channel1': 256, 'flanking_out_channel2': 512, 'fc1_out': 512, 'concat_dropout_rate_fc1': 0.4, 'fc2_out': 8, 'concat_dropout_rate_fc2': 0}. Best is trial 3 with value: 0.5428.[0m


fold 3, epoch 120, val loss 15.452976107597351 val accuracy 0.5337
fold 1, epoch 20, val loss 7.867133677005768 val accuracy 0.5247
fold 1, epoch 40, val loss 8.013933598995209 val accuracy 0.5257
fold 1, epoch 60, val loss 8.19753623008728 val accuracy 0.5191
fold 2, epoch 20, val loss 7.821981966495514 val accuracy 0.5119
fold 2, epoch 40, val loss 7.985138773918152 val accuracy 0.5153
fold 2, epoch 60, val loss 8.184622168540955 val accuracy 0.5194
fold 3, epoch 20, val loss 7.774190604686737 val accuracy 0.5235
fold 3, epoch 40, val loss 7.867578506469727 val accuracy 0.535


[32m[I 2023-11-22 22:04:25,569][0m Trial 7 finished with value: 0.5235666666666666 and parameters: {'lr': 1.8695847148568695e-05, 'l2_lambda': 6.453259344049084e-07, 'batch_size': 512, 'epochs': 60, 'flanking_out_channel1': 128, 'flanking_out_channel2': 128, 'fc1_out': 128, 'concat_dropout_rate_fc1': 0, 'fc2_out': 8, 'concat_dropout_rate_fc2': 0.1}. Best is trial 3 with value: 0.5428.[0m


fold 3, epoch 60, val loss 7.99491560459137 val accuracy 0.5322
