In [2]:
import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, random_split
import pandas as pd
import numpy as np
import os
import random
from collections import defaultdict, Counter
from itertools import combinations
import json
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import KFold
import optuna
from torchmetrics.classification import F1Score
import pickle

In [3]:
class BS_LS_DataSet_Prep():
    def __init__(self, BS_LS_coordinates_path, hg19_seq_dict_json_path,
                 flanking_dict_folder,
                 flanking_junction_bps=None,
                 flanking_intron_bps=None,
                 training_size=None):

        self.BS_LS_coordinates_path = BS_LS_coordinates_path
        self.flanking_junction_bps = flanking_junction_bps
        self.flanking_intron_bps = flanking_intron_bps
        self.flanking_dict_folder = flanking_dict_folder
        ### bring in the hg19 genomic sequences
        with open(hg19_seq_dict_json_path) as f:
            self.hg19_seq_dict = json.load(f)
        self.junction_seq_json_name = f'BS_LS_junction_seq_{self.flanking_junction_bps}_bps'
        self.flanking_seq_json_name = f'BS_LS_intronic_flanking_seq_{self.flanking_intron_bps}_bps'
        self.training_size = training_size

    def reverse_complement(self, input_seq):
        '''
        Return the reverse complement of a DNA sequence.

        :param input_seq: sequence made of the bases A, C, G, T and N (upper case)
        :return: string with the complemented bases of input_seq in reverse order;
                 any other character raises KeyError
        '''
        pairing = {'A': 'T', 'G': 'C', 'T': 'A', 'C': 'G', 'N': 'N'}
        complemented = []
        # walk the input backwards and complement base by base
        for base in reversed(input_seq):
            complemented.append(pairing[base])
        return ''.join(complemented)

    def get_junction_flanking_intron_seq(self):
        '''
        :param BS_LS_coordinates_path:
        :param flanking_bps:
        :return: two dictionaries
        1) junction_seq dictionary that stores the upper_intron, upper_exon, lower_exon, and lower_intron
        sequences with 50bps (default) each for a given LS or BS site with the key:
        chr|up_end|start|end|down_start|strand

        2) intronic_flanking_seq dictionary stores the flanking intronic sequence for the exon pairs with the key:
        chr|start|end|strand
        '''

        junction_seq = {}
        intronic_flanking_seq = {}

        ## retain the BS exons that have no valid boundaries
        BS_LS_coordinates_df = pd.read_csv(self.BS_LS_coordinates_path, sep='\t')
        for _, row in BS_LS_coordinates_df.iterrows():
            chrom, strand, start, end, label = row['chr'], row['strand'], int(row['start']), int(row['end']),\
                                                                   row['Splicing_type']
            # this key is unique for each instance in the BS_LS dataframe
            key = '|'.join([chrom, str(start), str(end), strand])

            # get the corresponding chromosome DNA sequence
            dna_seq = self.hg19_seq_dict[chrom]
            # extract the spliced genomic sequence assuming the positive strand
            spliced_seq_P = dna_seq[start: end].upper()
            # extract the upper_intron junction seq assuming the positive strand

            ### use 200 introns here by addition of 100bps, remove 100 to get the original data
            # upper_intron_P = dna_seq[start - self.flanking_bps-100: start].upper()
            upper_intron_P = dna_seq[start - self.flanking_junction_bps: start].upper()

            # extract the lower_intron junction seq assuming the positive strand
            # lower_intron_P = dna_seq[end: end + self.flanking_bps+100].upper()
            lower_intron_P = dna_seq[end: end + self.flanking_junction_bps].upper()

            # skip the rows that have 'Ns' in the upper_intron or lower_intron
            if 'N' in upper_intron_P or 'N' in lower_intron_P:
                print(f'{key} has N in the extracted junctions, belongs to {label}')
                continue

            # get the upper flanking sequence for the positive strand: instead of
            # using the up_end position use the start - 1000 as the starting point

            U_flanking_seq_P = dna_seq[start-self.flanking_intron_bps:start].upper()

            # get the lower flanking sequence using down_start + 1000 as the ending point
            L_flanking_seq_P = dna_seq[end:end+self.flanking_intron_bps].upper()

            # consider if the strand is -
            if strand == '-':
                ### get the junction sequence
                spliced_seq_N = self.reverse_complement(spliced_seq_P)
                upper_exon_N = spliced_seq_N[:self.flanking_junction_bps]
                lower_exon_N = spliced_seq_N[-self.flanking_junction_bps:]

                upper_intron_N = self.reverse_complement(lower_intron_P)
                lower_intron_N = self.reverse_complement(upper_intron_P)

                ### get the flanking intronic sequence
                U_flanking_seq = self.reverse_complement(L_flanking_seq_P)
                L_flanking_seq = self.reverse_complement(U_flanking_seq_P)

                intronic_flanking_seq[key] = {'U_flanking_seq': U_flanking_seq,
                                              'L_flanking_seq': L_flanking_seq,
                                              'label': label}

                junction_seq[key] = {'spliced_seq': spliced_seq_N,
                                     'upper_intron_' + str(self.flanking_junction_bps): upper_intron_N,
                                     'upper_exon_' + str(self.flanking_junction_bps): upper_exon_N,
                                     'lower_exon_' + str(self.flanking_junction_bps): lower_exon_N,
                                     'lower_intron_' + str(self.flanking_junction_bps): lower_intron_N,
                                     'label': label}

            else:
                upper_exon_P = spliced_seq_P[:self.flanking_junction_bps]

                lower_exon_P = spliced_seq_P[-self.flanking_junction_bps:]

                junction_seq[key] = {'spliced_seq': spliced_seq_P,
                                     'upper_intron_' + str(self.flanking_junction_bps): upper_intron_P,
                                     'upper_exon_' + str(self.flanking_junction_bps): upper_exon_P,
                                     'lower_exon_' + str(self.flanking_junction_bps): lower_exon_P,
                                     'lower_intron_' + str(self.flanking_junction_bps): lower_intron_P,
                                     'label': label}

                intronic_flanking_seq[key] = {'U_flanking_seq': U_flanking_seq_P,
                                              'L_flanking_seq': L_flanking_seq_P,
                                              'label': label}

        ### remove the repeative sequence from the junction_seq dictionary and then
        ### remove the overlapping sequence between BS and LS
        BS_junction_seq_dict = {}
        LS_junction_seq_dict = {}

        for i, j in junction_seq.items():
            if j['label'] == 'BS':
                BS_junction_seq_dict[i] = j['upper_intron_' + str(self.flanking_junction_bps)] + \
                                          j['upper_exon_' + str(self.flanking_junction_bps)] + \
                                          j['lower_exon_' + str(self.flanking_junction_bps)] + \
                                          j['lower_intron_' + str(self.flanking_junction_bps)]
            else:
                LS_junction_seq_dict[i] = j['upper_intron_' + str(self.flanking_junction_bps)] + \
                                          j['upper_exon_' + str(self.flanking_junction_bps)] + \
                                          j['lower_exon_' + str(self.flanking_junction_bps)] + \
                                          j['lower_intron_' + str(self.flanking_junction_bps)]

        # first get the 43 overlapping junction sequences and then filter the two dict from these sequences
        overlap_junction_seqs = set(BS_junction_seq_dict.values()).intersection(LS_junction_seq_dict.values())

        # remove these 43 overlapping junction sequences from BS_junction_seq_dict
        BS_junction_seq_dict_wo_overlap = {key: value for key, value in BS_junction_seq_dict.items() if
                                           value not in overlap_junction_seqs}

        # remove the duplicated junction sequences from BS by using the value as the key and then reverse it
        BS_tem_dict = {value: key for key, value in BS_junction_seq_dict_wo_overlap.items()}
        BS_res_dict = {value: key for key, value in BS_tem_dict.items()}

        # remove these 43 overlapping junction sequences from LS_junction_seq_dict
        LS_junction_seq_dict_wo_overlap = {key: value for key, value in LS_junction_seq_dict.items() if
                                           value not in overlap_junction_seqs}

        # remove the duplicated junction sequences from LS by using the value as the key and then reverse it
        LS_tem_dict = {value: key for key, value in LS_junction_seq_dict_wo_overlap.items()}
        LS_res_dict = {value: key for key, value in LS_tem_dict.items()}

        # merge two dicts
        BS_LS_res_dict = {**BS_res_dict, **LS_res_dict}

        # filter the junction_seq with the new keys from BS_LS_res_dict
        junction_seq_final = {key: junction_seq[key] for key in BS_LS_res_dict.keys()}

        # filter the intronic_flanking_seq with the new keys from BS_LS_res_dict
        intronic_flanking_seq_final = {key: intronic_flanking_seq[key] for key in BS_LS_res_dict.keys()}

        ## save the junction_seq and intronic_flanking_seq to json on the harddrive

        with open(f'{self.flanking_dict_folder}{self.junction_seq_json_name}.json', 'w') as f:
            json.dump(junction_seq_final, f)

        with open(f'{self.flanking_dict_folder}{self.flanking_seq_json_name}.json', 'w') as f:
            json.dump(intronic_flanking_seq_final, f)

    #         return junction_seq, intronic_flanking_seq

    def get_train_test_keys(self):

        if not os.path.exists(f'{self.flanking_dict_folder}{self.junction_seq_json_name}.json'):
            # invoke the function to write the BS_LS_junction_seq and BS_LS_intronic_flanking_seq to the drive
            self.get_junction_flanking_intron_seq()
            with open(f'{self.flanking_dict_folder}{self.junction_seq_json_name}.json') as f:
                BS_LS_junction_seq_dict = json.load(f)
        else:
            # read the BS_LS_junction_seq from harddrive instead of calling the function
            with open(f'{self.flanking_dict_folder}{self.junction_seq_json_name}.json') as f:
                BS_LS_junction_seq_dict = json.load(f)

        BS_exon_key_list = [key for key, value in BS_LS_junction_seq_dict.items() if value['label'] == 'BS']
        LS_exon_key_list = [key for key, value in BS_LS_junction_seq_dict.items() if value['label'] == 'LS']

        np.random.seed(42)
        # randomly select 12000 from both BS and LS key list and combine them
        BS_exon_train_keys = np.random.choice(BS_exon_key_list, size=self.training_size, replace=False)
        LS_exon_train_keys = np.random.choice(LS_exon_key_list, size=self.training_size, replace=False)
        BS_LS_exon_train_keys = np.concatenate([BS_exon_train_keys, LS_exon_train_keys])

        # select the remaining keys from BS and LS key list as the test keys
        BS_exon_test_keys = set(BS_exon_key_list).difference(BS_exon_train_keys)
        LS_exon_test_keys = set(LS_exon_key_list).difference(LS_exon_train_keys)
        BS_LS_exon_test_keys = np.array(list(BS_exon_test_keys.union(LS_exon_test_keys)))

        return BS_LS_exon_train_keys, BS_LS_exon_test_keys

    def seq_to_matrix(self, input_seq):
        '''
            One-hot encode a DNA sequence as a 4 X N matrix, N being len(input_seq).

            Row order is A, G, C, T. Any other character (e.g. 'N') raises KeyError,
            so such sequences have to be excluded before calling this function.
        '''
        base_to_row = {'A': 0, 'G': 1, 'C': 2, 'T': 3}
        seq_len = len(input_seq)

        # row of the "hot" entry for every column
        hot_rows = [base_to_row[base] for base in input_seq]

        one_hot = np.zeros((4, seq_len))
        one_hot[np.asarray(hot_rows, dtype=np.intp), np.arange(seq_len)] = 1
        return one_hot
    
    #### create all sequence features, rcm_features and a2i features simultaneously here
    def seq_to_tensor(self, data_keys, rcm_folder, is_rcm=None, is_upper_lower_concat=None):
        '''
        :param data_keys:
        :param rcm_folder
        :param is_rcm: boolean value indicate if rcm features is genearated or not
        :param is_upper_lower_concat: boolean value indicates if upper and lower junction is concatenate or not
        :return: concatenated 2-d data for upper and lower seq for all the keys in the data_keys
        '''
        ### first get the BS_LS_junction_seq_dict;
        if not os.path.exists(f'{self.flanking_dict_folder}{self.junction_seq_json_name}.json'):
            self.get_junction_flanking_intron_seq()
            with open(f'{self.flanking_dict_folder}{self.junction_seq_json_name}.json') as f:
                BS_LS_junction_seq_dict = json.load(f)
        else:
            with open(f'{self.flanking_dict_folder}{self.junction_seq_json_name}.json') as f:
                BS_LS_junction_seq_dict = json.load(f)

        ### then get the BS_LS_flanking_seq_dict;
        if not os.path.exists(f'{self.flanking_dict_folder}{self.flanking_seq_json_name}.json'):
            self.get_junction_flanking_intron_seq()
            with open(f'{self.flanking_dict_folder}{self.flanking_seq_json_name}.json') as f:
                BS_LS_flanking_seq_dict = json.load(f)
        else:
            with open(f'{self.flanking_dict_folder}{self.flanking_seq_json_name}.json') as f:
                BS_LS_flanking_seq_dict = json.load(f)

        ### these two dictionary use the same keys
        ### list to store the concatenated upper and lower sequence one-hot encoding
        if is_upper_lower_concat:
            all_torch_feature_list = []
            
        ### list to store the upper and lower sequence one-hot encoding separately 
        else:
            all_torch_upper_feature_list = []
            all_torch_lower_feature_list = []
            
        ### list to store rcm features if is_rcm true ####using for loops!!!!!!
        if is_rcm:
            ### tri-cnn for flanking, upper and lower separately
            flanking_rcm_scores = []
            upper_rcm_scores = []
            lower_rcm_scores = []
            
            
            flanking_rcm_dict_list = []
            upper_rcm_dict_list = []
            lower_rcm_dict_list = []
            
            seed_len_list = [5, 7, 9, 11, 13]
            
#             for flanking_intron_len in [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000,
#                                         1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800,
#                                         1900, 2000]:
            for flanking_intron_len in [100, 500, 1000,1500, 2000]:
                
                for seed_len in seed_len_list:
            
                    with open(os.path.join(rcm_scores_folder, f'to_519_rcm_flanking_{flanking_intron_len}_bps_introns_{seed_len}mer.json')) as f:
                        flanking_rcm_dict = json.load(f)
                        flanking_rcm_dict_list.append(flanking_rcm_dict)
                    with open(os.path.join(rcm_scores_folder, f'to_519_rcm_upper_{flanking_intron_len}_bps_introns_{seed_len}mer.json')) as f:
                        upper_rcm_dict = json.load(f)
                        upper_rcm_dict_list.append(upper_rcm_dict)
                    with open(os.path.join(rcm_scores_folder, f'to_519_rcm_lower_{flanking_intron_len}_bps_introns_{seed_len}mer.json')) as f:
                        lower_rcm_dict = json.load(f)
                        lower_rcm_dict_list.append(lower_rcm_dict)
  
            
        ### list to store the corresponding labels: LS or BS
        all_label_list = []


        for key in data_keys:
            ### get the rcm features from the harddrive
            ### should produce the RCM features simultaneously with the seq feture for the same data point

            flanking_seqs = BS_LS_flanking_seq_dict[key]

            ## construct the rcm feature if is_rcm is True
            if is_rcm:
                
                flanking_rcm_kmer_list = [np.log(np.array(flanking_rcm[key][0]).reshape(5,5)+1) for flanking_rcm in flanking_rcm_dict_list]
                flanking_rcm_kmers = torch.from_numpy(np.concatenate(flanking_rcm_kmer_list, axis=1)).to(torch.float32)
                          
                upper_rcm_kmer_list = [np.log(np.array(upper_rcm[key][0]).reshape(5,5)+1) for upper_rcm in upper_rcm_dict_list]
                upper_rcm_kmers = torch.from_numpy(np.concatenate(upper_rcm_kmer_list, axis=1)).to(torch.float32)
                          
                lower_rcm_kmer_list = [np.log(np.array(lower_rcm[key][0]).reshape(5,5)+1) for lower_rcm in lower_rcm_dict_list]
                lower_rcm_kmers = torch.from_numpy(np.concatenate(lower_rcm_kmer_list, axis=1)).to(torch.float32)
                          
                
#                 rcm_values_concate = torch.from_numpy(np.stack([flanking_rcm_kmers,
#                                                             upper_rcm_kmers,
#                                                             lower_rcm_kmers], axis=0)).to(torch.float32)
                
                flanking_rcm_scores.append(flanking_rcm_kmers)
                upper_rcm_scores.append(upper_rcm_kmers)
                lower_rcm_scores.append(lower_rcm_kmers)
                
                
            ### working with the junction_seq starting here
            value = BS_LS_junction_seq_dict[key]
            # extract the poisitve (BS: as 1) and negative label (LS: as 0)

            ## make sure the label are the same for the same key and append the labe to the list
            label = value['label']
            assert label == flanking_seqs['label'], f"Same sequence key {key} with different labels"

            if label == 'BS':
                label = 1
            else:
                label = 0
                
            all_label_list.append(label)

            # concatenate the upper seq together and lower seq together for two separate CNN to process
            concatenated_upper_seq = value['upper_intron_{}'.format(str(self.flanking_junction_bps))] + \
                               value['upper_exon_{}'.format(str(self.flanking_junction_bps))]
            concatenated_lower_seq = value['lower_exon_{}'.format(str(self.flanking_junction_bps))] + \
                               value['lower_intron_{}'.format(str(self.flanking_junction_bps))]
            
            ### test whether want to concatenate the upper seq and lower seq together
            if is_upper_lower_concat:
                
                upper_lower_concat_seq = concatenated_upper_seq + concatenated_lower_seq
                upper_lower_concat_mat = self.seq_to_matrix(upper_lower_concat_seq)
                individual_upper_lower_torch = torch.from_numpy(upper_lower_concat_mat).to(torch.float32)
                all_torch_feature_list.append(individual_upper_lower_torch)
                
            else:
                
                individual_upper_mat = self.seq_to_matrix(concatenated_upper_seq)
                individual_lower_mat = self.seq_to_matrix(concatenated_lower_seq)

                # convert individual instance to torch
                individual_upper_torch = torch.from_numpy(individual_upper_mat).to(torch.float32)
                individual_lower_torch = torch.from_numpy(individual_lower_mat).to(torch.float32)

                all_torch_upper_feature_list.append(individual_upper_torch)
                all_torch_lower_feature_list.append(individual_lower_torch)
                
                
        all_torch_label = torch.tensor(all_label_list, dtype=torch.float32) #.view(-1, 1)
        
        if is_rcm:
            all_torch_flanking_rcm_features = torch.stack(flanking_rcm_scores, dim=0)
            all_torch_upper_rcm_features = torch.stack(upper_rcm_scores, dim=0)
            all_torch_lower_rcm_features = torch.stack(lower_rcm_scores, dim=0)
            
            
#             all_torch_rcm_feature = torch.stack(rcm_scores, dim=0)
            
        if is_upper_lower_concat:
            
            all_torch_feature = torch.stack(all_torch_feature_list, dim=0)
            
        else:
            all_torch_upper_feature = torch.stack(all_torch_upper_feature_list, dim=0)
            all_torch_lower_feature = torch.stack(all_torch_lower_feature_list, dim=0)
            
        ### return the tensor based on several requirement
        ## return only the upper, lower and rcm / upper lower, a2i / or both/ or just upper lower /or just concate
        
        if is_upper_lower_concat and not is_rcm:
            
            return all_torch_feature, all_torch_label
        
        if is_upper_lower_concat and is_rcm:
            return all_torch_feature, all_torch_flanking_rcm_features, all_torch_upper_rcm_features, all_torch_lower_rcm_features, all_torch_label
        
        if not is_upper_lower_concat and not is_rcm:
            return all_torch_upper_feature, all_torch_lower_feature, all_torch_label
        
        if not is_upper_lower_concat and is_rcm:
            return all_torch_upper_feature, all_torch_lower_feature, all_torch_flanking_rcm_features, all_torch_upper_rcm_features, all_torch_lower_rcm_features, all_torch_label
            

In [4]:
class BS_LS_upper_lower_concat_rcm(Dataset):
    '''
        Map-style dataset over the concatenated upper/lower sequence features and,
        optionally, the flanking/upper/lower RCM features.
    '''
    def __init__(self, include_rcm, seq_upper_lower_feature, flanking_rcm, upper_rcm, lower_rcm, label):
        '''
        :param include_rcm: truthy to store and serve the three RCM feature sets
        :param seq_upper_lower_feature: tensor whose first dimension indexes the samples
        :param flanking_rcm, upper_rcm, lower_rcm: RCM features; ignored (and not
               stored) when include_rcm is falsy
        :param label: labels indexed like the features
        '''
        self.include_rcm = include_rcm

        # the sequence feature is always served, the RCM features only on request
        self.x1 = seq_upper_lower_feature
        if self.include_rcm:
            self.x2, self.x3, self.x4 = flanking_rcm, upper_rcm, lower_rcm

        self.y = label
        self.n_samples = seq_upper_lower_feature.size(0)

    def __getitem__(self, index):
        # dataset[index] -> (x1, y) or (x1, x2, x3, x4, y)
        if not self.include_rcm:
            return self.x1[index], self.y[index]
        return self.x1[index], self.x2[index], self.x3[index], self.x4[index], self.y[index]

    def __len__(self):
        # len(dataset) -> number of samples in the sequence feature
        return self.n_samples

In [27]:
### Model 1 input sequence 4 X 400 with 1 or 2CNN layer
### remove the dropout for CNN layers

class Model1_optuna(nn.Module):
    '''
        1-D CNN feature extractor for an input sequence of 4 X 400.

        One or two Conv1d -> BatchNorm1d -> ReLU -> max-pool blocks (no dropout),
        followed by flattening to (-1, channels * pooled length). All hyper-parameters
        are fixed; the optuna search over them is switched off, so `trial` is unused.
    '''
    def __init__(self, trial, cnn_num):
        '''
        :param trial: optuna trial, kept for interface compatibility (not read)
        :param cnn_num: 2 builds and applies the second CNN block, any other value
                        uses the first block only
        '''
        super(Model1_optuna, self).__init__()

        self.cnn_num = cnn_num

        ### first CNN block
        # search space when tuning: out_channel1 in [32, 64, 128, 256, 512],
        # kernel_size1 in [13, 15, 17, 19, 21], maxpool1 in [5, 10, 20]
        self.out_channel1 = 512
        kernel_size1 = 13
        same_padding1 = (kernel_size1 - 1) // 2  # keeps the sequence length unchanged
        self.conv1 = nn.Conv1d(in_channels=4,
                               out_channels=self.out_channel1,
                               kernel_size=kernel_size1,
                               stride=1,
                               padding=same_padding1)
        self.conv1_bn = nn.BatchNorm1d(self.out_channel1)

        self.maxpool1 = 5
        # length left after pooling an input of length 400
        self.conv1_out_dim = 400 // self.maxpool1

        if self.cnn_num == 2:
            ### second CNN block
            # search space when tuning: out_channel2 in [32, 64, 128, 256, 512],
            # kernel_size2 in [13, 15, 17, 19, 21], maxpool2 in [5, 10, 20]
            self.out_channel2 = 512
            kernel_size2 = 21
            same_padding2 = (kernel_size2 - 1) // 2
            self.conv2 = nn.Conv1d(in_channels=self.out_channel1,
                                   out_channels=self.out_channel2,
                                   kernel_size=kernel_size2,
                                   stride=1,
                                   padding=same_padding2)
            self.conv2_bn = nn.BatchNorm1d(self.out_channel2)

            self.maxpool2 = 5
            self.conv2_out_dim = 400 // (self.maxpool1 * self.maxpool2)

    def forward(self, x):
        features = torch.relu(self.conv1_bn(self.conv1(x)))
        features = F.max_pool1d(features, self.maxpool1)

        if self.cnn_num != 2:
            # single block: flatten the output of the first block
            return features.view(-1, self.out_channel1 * self.conv1_out_dim)

        features = torch.relu(self.conv2_bn(self.conv2(features)))
        features = F.max_pool1d(features, self.maxpool2)
        return features.view(-1, self.out_channel2 * self.conv2_out_dim)


class RCM_optuna_flanking(nn.Module):
    '''
        CNN over the RCM score distribution of the flanking introns.

        Two Conv1d (kernel 5, stride 5, no padding) -> BatchNorm1d -> ReLU blocks
        (128 then 32 channels), flattened to (-1, 32 * 5). The input needs 5 channels;
        conv2_out_dim = 5 presumes a length of 125 (125 / 5 / 5 = 5).
        The optuna search is switched off, so `trial` is unused.
    '''
    def __init__(self, trial):

        super(RCM_optuna_flanking, self).__init__()

        # non-overlapping windows: kernel and stride share the same width
        window = 5

        # conv block 1 (search space when tuning: [32, 64, 128, 256, 512] channels)
        self.out_channel1 = 128
        self.conv1 = nn.Conv1d(in_channels=5, out_channels=self.out_channel1,
                               kernel_size=window, stride=window, padding=0)
        self.conv1_bn = nn.BatchNorm1d(self.out_channel1)

        # conv block 2 (search space when tuning: [32, 64, 128, 256, 512] channels)
        self.out_channel2 = 32
        self.conv2 = nn.Conv1d(in_channels=self.out_channel1, out_channels=self.out_channel2,
                               kernel_size=window, stride=window, padding=0)
        self.conv2_bn = nn.BatchNorm1d(self.out_channel2)

        # length that is left after both strided convolutions
        self.conv2_out_dim = 5

    def forward(self, x):
        hidden = torch.relu(self.conv1_bn(self.conv1(x)))
        hidden = torch.relu(self.conv2_bn(self.conv2(hidden)))
        return hidden.view(-1, self.out_channel2 * self.conv2_out_dim)


    
class RCM_optuna_upper(nn.Module):
    '''
        CNN over the RCM score distribution of the upper introns.

        Conv1d(5 -> 512) and Conv1d(512 -> 64), both with kernel 5 / stride 5 / no
        padding and each followed by BatchNorm1d and ReLU; the result is flattened to
        (-1, 64 * 5). conv2_out_dim = 5 presumes an input length of 125.
        The optuna search is switched off, so `trial` is unused.
    '''
    def __init__(self, trial):

        super(RCM_optuna_upper, self).__init__()

        # shared geometry of both convolutions
        conv_geometry = dict(kernel_size=5, stride=5, padding=0)

        # conv layer 1 (search space when tuning: [32, 64, 128, 256, 512] channels)
        self.out_channel1 = 512
        self.conv1 = nn.Conv1d(in_channels=5, out_channels=self.out_channel1, **conv_geometry)
        self.conv1_bn = nn.BatchNorm1d(self.out_channel1)

        # conv layer 2 (search space when tuning: [32, 64, 128, 256, 512] channels)
        self.out_channel2 = 64
        self.conv2 = nn.Conv1d(in_channels=self.out_channel1, out_channels=self.out_channel2,
                               **conv_geometry)
        self.conv2_bn = nn.BatchNorm1d(self.out_channel2)

        # positions remaining after the two stride-5 convolutions
        self.conv2_out_dim = 5

    def forward(self, x):
        out = x
        # apply conv -> batch norm -> relu for each of the two blocks in turn
        for conv, batch_norm in ((self.conv1, self.conv1_bn), (self.conv2, self.conv2_bn)):
            out = torch.relu(batch_norm(conv(out)))

        flat_dim = self.out_channel2 * self.conv2_out_dim
        return out.view(-1, flat_dim)
    
    
class RCM_optuna_lower(nn.Module):
    '''
        CNN over the RCM score distribution of the lower introns.

        Conv1d(5 -> 512) and Conv1d(512 -> 512), both with kernel 5 / stride 5 / no
        padding and each followed by BatchNorm1d and ReLU; the result is flattened to
        (-1, 512 * 5). conv2_out_dim = 5 presumes an input length of 125.
        The optuna search is switched off, so `trial` is unused.
    '''
    def __init__(self, trial):

        super(RCM_optuna_lower, self).__init__()

        # conv layer 1 (search space when tuning: [32, 64, 128, 256, 512] channels)
        kernel_size1 = 5
        self.out_channel1 = 512
        self.conv1 = nn.Conv1d(
            in_channels=5,
            out_channels=self.out_channel1,
            kernel_size=kernel_size1,
            stride=5,
            padding=0,
        )
        self.conv1_bn = nn.BatchNorm1d(self.out_channel1)

        # conv layer 2 (search space when tuning: [32, 64, 128, 256, 512] channels)
        self.out_channel2 = 512
        self.conv2 = nn.Conv1d(
            in_channels=self.out_channel1,
            out_channels=self.out_channel2,
            kernel_size=5,
            stride=5,
            padding=0,
        )
        self.conv2_bn = nn.BatchNorm1d(self.out_channel2)

        # positions remaining after the two stride-5 convolutions
        self.conv2_out_dim = 5

    def forward(self, x):
        first = torch.relu(self.conv1_bn(self.conv1(x)))
        second = torch.relu(self.conv2_bn(self.conv2(first)))
        return second.view(-1, self.out_channel2 * self.conv2_out_dim)
        
        

class RCM_optuna_concate(nn.Module):
    '''
        Fuse the three RCM CNN branches (flanking, upper, lower) with a small MLP head.

        Each branch output is concatenated along dim 1 and passed through two
        Linear -> BatchNorm1d -> ReLU -> Dropout layers. The result has
        ``self.fc2_out`` features per sample; there is no final classification
        layer here. Layer widths and dropout rates are fixed, so ``trial`` is
        only forwarded to the branch constructors.
    '''
    def __init__(self, trial):
        super().__init__()

        # One CNN per RCM score track. Each branch advertises its flattened
        # size through ``conv2_out_dim`` (length) and ``out_channel2`` (channels).
        self.cnn_flanking = RCM_optuna_flanking(trial)
        self.flanking_out_dim = self.cnn_flanking.conv2_out_dim
        self.flanking_out_channel = self.cnn_flanking.out_channel2

        self.cnn_upper = RCM_optuna_upper(trial)
        self.upper_out_dim = self.cnn_upper.conv2_out_dim
        self.upper_out_channel = self.cnn_upper.out_channel2

        self.cnn_lower = RCM_optuna_lower(trial)
        self.lower_out_dim = self.cnn_lower.conv2_out_dim
        self.lower_out_channel = self.cnn_lower.out_channel2

        # Width of the concatenated branch features that feed fc1.
        branch_widths = (
            self.flanking_out_dim * self.flanking_out_channel,
            self.upper_out_dim * self.upper_out_channel,
            self.lower_out_dim * self.lower_out_channel,
        )
        self.fc1_input_dim = sum(branch_widths)

        # ---- fc layer 1 ----
        # Fixed at 256; the Optuna search ('concat_fc1_out' over
        # [32, 64, 128, 256, 512]) is disabled.
        self.fc1_out = 256
        self.fc1 = nn.Linear(self.fc1_input_dim, self.fc1_out)
        self.fc1_bn = nn.BatchNorm1d(self.fc1_out)
        # Dropout disabled (p=0); 'concat_dropout_rate_fc1' is no longer tuned here.
        self.drop_nn1 = nn.Dropout(p=0)

        # ---- fc layer 2 ----
        # Fixed at 64; the Optuna search ('concat_fc2_out' over
        # [8, 16, 32, 64, 128]) is disabled.
        self.fc2_out = 64
        self.fc2 = nn.Linear(self.fc1_out, self.fc2_out)
        self.fc2_bn = nn.BatchNorm1d(self.fc2_out)
        # Dropout disabled (p=0); 'concat_dropout_rate_fc2' is no longer tuned here.
        self.drop_nn2 = nn.Dropout(p=0)

    def forward(self, rcm_flanking, rcm_upper, rcm_lower):
        """Encode the three RCM inputs and return the fused ``fc2_out``-wide features.

        Each argument is handed to its own branch CNN; the branches are
        assumed to return 2-D (batch, features) tensors (true for
        ``RCM_optuna_lower``; TODO confirm for the other two).
        """
        branch_features = [
            self.cnn_flanking(rcm_flanking),
            self.cnn_upper(rcm_upper),
            self.cnn_lower(rcm_lower),
        ]
        fused = torch.cat(branch_features, dim=1)

        hidden = self.drop_nn1(torch.relu(self.fc1_bn(self.fc1(fused))))
        return self.drop_nn2(torch.relu(self.fc2_bn(self.fc2(hidden))))
    
        

class ConcatModel1_optuna(nn.Module):
    '''
        Junction-sequence CNN, optionally fused with the tri-branch RCM CNN,
        followed by an MLP classifier that emits 2 logits (for nn.CrossEntropyLoss).

        Args:
            trial: Optuna trial used to sample the MLP widths and dropout rates
                   (and forwarded to the sub-networks).
            cnn_num: number of conv layers in ``Model1_optuna``; 2 selects its
                     ``conv2_out_dim``/``out_channel2``, anything else its
                     ``conv1_out_dim``/``out_channel1``.
            include_rcm_tri_cnn: when truthy, an ``RCM_optuna_concate`` branch
                     is built and its output is concatenated with the sequence
                     CNN features.
    '''
    def __init__(self, trial, cnn_num, include_rcm_tri_cnn):
        # nn.Module requires the parent __init__ to run before any attribute
        # is assigned on the instance.
        super(ConcatModel1_optuna, self).__init__()

        self.cnn_num = cnn_num

        ### whether to add the tri-branch RCM CNN next to the sequence CNN
        self.include_rcm_tri_cnn = include_rcm_tri_cnn

        ### cnn for the concatenated sequence
        self.cnn = Model1_optuna(trial, self.cnn_num)

        if self.cnn_num == 2:
            # this is for two convlayer
            self.out_dim = self.cnn.conv2_out_dim
            self.out_channel = self.cnn.out_channel2
        else:
            # this is for one convlayer
            self.out_dim = self.cnn.conv1_out_dim
            self.out_channel = self.cnn.out_channel1

        if self.include_rcm_tri_cnn:
            self.rcm_tri_cnn = RCM_optuna_concate(trial)
            # sequence CNN features + the fc2 output of the RCM branch
            self.fc1_input_dim = self.rcm_tri_cnn.fc2_out + self.out_channel * self.out_dim
        else:
            ### just use the junction sequence of singleCNN
            self.fc1_input_dim = self.out_channel * self.out_dim

        # fc layer 1
        self.fc1_out = trial.suggest_categorical('concat_fc1_out', [64, 128, 256, 512])
        self.fc1 = nn.Linear(self.fc1_input_dim, self.fc1_out)
        self.fc1_bn = nn.BatchNorm1d(self.fc1_out)

        dropout_rate_fc1 = trial.suggest_categorical("concat_dropout_rate_fc1",  [0, 0.1, 0.2, 0.4])
        self.drop_nn1 = nn.Dropout(p=dropout_rate_fc1)

        # fc layer 2
        self.fc2_out = trial.suggest_categorical('concat_fc2_out', [8, 16, 32, 64])
        self.fc2 = nn.Linear(self.fc1_out, self.fc2_out)
        self.fc2_bn = nn.BatchNorm1d(self.fc2_out)

        dropout_rate_fc2 = trial.suggest_categorical("concat_dropout_rate_fc2",[0, 0.1, 0.2, 0.4])
        self.drop_nn2 = nn.Dropout(p=dropout_rate_fc2)

        # output layer: 2 logits, to be used with nn.CrossEntropyLoss()
        self.fc3 = nn.Linear(self.fc2_out, 2)

    def forward(self, seq_upper_lower_feature, rcm_flanking=None, rcm_upper=None, rcm_lower=None):
        """Return the (batch, 2) logits.

        Args:
            seq_upper_lower_feature: input of the sequence CNN.
            rcm_flanking, rcm_upper, rcm_lower: inputs of the RCM branch;
                required when the model was built with ``include_rcm_tri_cnn``
                truthy, ignored otherwise.

        Raises:
            ValueError: if the RCM branch is enabled but any RCM input is None.
        """
        x1 = self.cnn(seq_upper_lower_feature)

        if self.include_rcm_tri_cnn:
            if rcm_flanking is None or rcm_upper is None or rcm_lower is None:
                # Fail with a clear message instead of an obscure error deep inside Conv1d.
                raise ValueError(
                    "include_rcm_tri_cnn is enabled: rcm_flanking, rcm_upper and "
                    "rcm_lower must all be provided"
                )
            x2 = self.rcm_tri_cnn(rcm_flanking, rcm_upper, rcm_lower)
            x = torch.cat((x1, x2), dim=1)
            ### normalization after concatenation: with its defaults (p=2, dim=1)
            ### F.normalize scales every row to unit L2 norm
            x = torch.nn.functional.normalize(x)
        else:
            x = x1

        # feed the (possibly concatenated) feature to fc1
        out = self.fc1(x)
        out = self.drop_nn1(torch.relu(self.fc1_bn(out)))

        out = self.fc2(out)
        out = self.drop_nn2(torch.relu(self.fc2_bn(out)))

        out = self.fc3(out)
        return out


In [28]:
def _move_batch_to_device(device, features, labels):
    """Cast labels to int64 (needed by nn.CrossEntropyLoss) and move the whole batch to ``device``."""
    return [feature.to(device) for feature in features], labels.long().to(device)


def _evaluate_epoch(device, model, loader, criterion=None):
    """Return ``(accuracy, summed_loss)`` of ``model`` over ``loader`` without tracking gradients.

    ``accuracy`` is rounded to 4 decimals. ``summed_loss`` is the sum of the
    per-batch criterion values (0.0 when ``criterion`` is None).
    """
    correct, total, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for *features, labels in loader:
            features, labels = _move_batch_to_device(device, features, labels)
            logits = model(*features)
            if criterion is not None:
                loss_sum += criterion(logits, labels).item()
            # predicted class = index of the largest logit
            _, predicted = torch.max(logits, 1)
            correct += (predicted == labels).sum().item()
            total += labels.shape[0]
    if total == 0:
        raise ValueError("the data loader did not yield any sample")
    return round(correct / total, 4), loss_sum


def Objective(device, trial, fold, model, optimizer,
              patience, epochs, criterion, train_loader,
              val_loader, model_folder):
    """Train ``model`` on one CV fold with early stopping on the validation loss.

    Args:
        device: torch device the batches are moved to.
        trial: object exposing ``number`` (used in the checkpoint file name).
        fold: zero-based fold index (reported and saved as ``fold + 1``).
        model, optimizer, criterion: the network, its optimizer and the loss.
        patience: number of consecutive non-improving epochs tolerated.
        epochs: maximum number of epochs.
        train_loader, val_loader: iterables yielding ``(*features, labels)``.
        model_folder: folder receiving ``fold{fold+1}_trial{number}.pt``
            (created if missing); the file is overwritten whenever the
            validation loss does not get worse.

    Returns:
        The validation accuracy (rounded to 4 decimals) of the epoch with the
        lowest summed validation loss, or 0.0 when no epoch produced a usable
        loss (``epochs == 0`` or the loss was NaN throughout).
    """
    os.makedirs(model_folder, exist_ok=True)
    model_path = f"{model_folder}/fold{fold+1}_trial{trial.number}.pt"

    # Start from +inf rather than an arbitrary "big number": the summed batch
    # loss can exceed any fixed constant, which used to leave best_val_acc unset.
    best_val_loss = float('inf')
    best_val_acc = 0.0
    stale_epochs = 0  # consecutive epochs without improvement

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        ## the star notation copes with datasets yielding a different number of features
        for *features, train_labels in train_loader:
            features, train_labels = _move_batch_to_device(device, features, train_labels)

            train_preds = model(*features)
            loss = criterion(train_preds, train_labels)

            optimizer.zero_grad()  # drop the gradient from the previous step
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        model.eval()
        report_epoch = (epoch + 1) % 20 == 0

        if report_epoch:
            # The training accuracy is only ever printed, so the extra pass over
            # the training set is limited to the reporting epochs.
            train_acc, _ = _evaluate_epoch(device, model, train_loader)
            print(f'fold {fold + 1}, epoch {epoch + 1}, training loss {running_loss}, train accuracy {train_acc}')

        val_acc, total_val_loss = _evaluate_epoch(device, model, val_loader, criterion)

        if report_epoch:
            print(f'fold {fold + 1}, epoch {epoch + 1}, val accuracy {val_acc}')

        ### early stopping check; keep the checkpoint with the lowest validation loss
        if total_val_loss <= best_val_loss:
            best_val_loss = total_val_loss
            best_val_acc = val_acc
            stale_epochs = 0
            ## overwrites the previous best model of this fold/trial
            torch.save(model, model_path)
        else:
            # a NaN loss also lands here because every comparison with NaN is False
            stale_epochs += 1
            if stale_epochs >= patience:
                break  # stop this fold

    return best_val_acc


In [29]:
class Objective_CV:
    """Optuna objective: k-fold cross-validated accuracy of one hyper-parameter sample.

    Calling the instance with a trial samples the learning rate, L2 weight
    decay and batch size, trains a fresh model on every fold through
    ``Objective`` and returns the mean of the per-fold validation accuracies.
    The per-fold accuracies are also appended to ``val_acc.csv`` (tab
    separated) inside ``val_acc_folder``.
    """

    def __init__(self, patience, cv, model, dataset, cnn_num, is_simple_cnn,
                 is_bicnn, include_rcm_tri_cnn, val_acc_folder, model_folder):

        self.patience = patience ## number of epochs for early stopping the model training
        self.cv = cv ## number of CV folds
        self.model = model ## model class, called as model(trial, cnn_num, include_rcm_tri_cnn)
        self.dataset = dataset ## the corresponding dataset object

        self.cnn_num = cnn_num ## number of CNN layers for either simple_cnn or bi_cnn

        ### which model family is being tuned (at least one flag must be truthy)
        self.is_simple_cnn = is_simple_cnn ## model that just concatenate junction sequences
        self.is_bicnn = is_bicnn ## model that use bi CNN

        self.include_rcm_tri_cnn = include_rcm_tri_cnn ## whether include rcm_cnn in the CNN

        self.val_acc_folder = val_acc_folder ## folder to store the cross_validation accuracy
        self.model_folder = model_folder ## folder to store the trained model for later testing dataset evaluation

    def __call__(self, trial):
        """Run the cross-validation for ``trial`` and return the mean validation accuracy.

        Raises:
            ValueError: if neither ``is_simple_cnn`` nor ``is_bicnn`` is set.
        """
        # Fail fast: without either flag no model was ever built, which used to
        # surface as an unbound ``model`` variable deep inside the fold loop.
        if not (self.is_simple_cnn or self.is_bicnn):
            raise ValueError("one of is_simple_cnn / is_bicnn must be True to build a model")

        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
        l2_lambda = trial.suggest_float("l2_lambda", 1e-8, 1e-3, log=True)
        batch_size = trial.suggest_categorical("batch_size", [128, 256, 512, 1024])

        ### upper bound on the number of epochs; early stopping usually ends sooner
        epochs = 150

        criterion = nn.CrossEntropyLoss()

        kfold = KFold(n_splits=self.cv, shuffle=True)

        val_acc_list = []

        for fold, (train_index, val_index) in enumerate(kfold.split(np.arange(len(self.dataset)))):

            train_subset = torch.utils.data.Subset(self.dataset, train_index)
            train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=4)

            # The validation loader is not shuffled: early stopping sums
            # per-batch mean losses, so a fixed batch composition keeps that
            # sum comparable from one epoch to the next.
            val_subset = torch.utils.data.Subset(self.dataset, val_index)
            val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=4)

            ## a new model (same hyper-parameters) is initialised for every fold;
            ## both model families share the same constructor signature
            model = self.model(trial, self.cnn_num, self.include_rcm_tri_cnn).to(device=device)

            optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_lambda)

            accuracy = Objective(device, trial, fold=fold, model=model, optimizer=optimizer,
                                 patience=self.patience, epochs=epochs, criterion=criterion,
                                 train_loader=train_loader, val_loader=val_loader,
                                 model_folder=self.model_folder)

            val_acc_list.append(accuracy)

        ## the trial is scored by the average validation accuracy over the folds
        avg_acc_val = np.mean(val_acc_list)

        # make sure the log folder exists so a finished trial is never lost
        os.makedirs(self.val_acc_folder, exist_ok=True)
        val_acc_path = f"{self.val_acc_folder}/val_acc.csv"

        val_acc_str = '\t'.join([str(i) for i in val_acc_list])
        with open(val_acc_path, 'a') as f:
            f.write('trial' + str(trial.number) + '\t' + val_acc_str + '\n')

        return avg_acc_val


In [17]:
### input locations: BS/LS coordinates, hg19 sequences, flanking dicts and RCM scores
BS_LS_coordinates_path = '/home/wangc90/circRNA/circRNA_Data/BS_LS_data/updated_data/BS_LS_coordinates_final.csv'
hg19_seq_dict_json_path = '/home/wangc90/circRNA/circRNA_Data/hg19_seq/hg19_seq_dict.json'
flanking_dict_folder = '/home/wangc90/circRNA/circRNA_Data/BS_LS_data/flanking_dicts/'
rcm_scores_folder = '/home/wangc90/circRNA/circRNA_Data/BS_LS_data/flanking_dicts/rcm_scores/'

### dataset helper: 100 bps around the junction, 4500 bps of flanking intron, 11000 training examples
bs_ls_dataset = BS_LS_DataSet_Prep(
    training_size=11000,
    flanking_intron_bps=4500,
    flanking_junction_bps=100,
    flanking_dict_folder=flanking_dict_folder,
    hg19_seq_dict_json_path=hg19_seq_dict_json_path,
    BS_LS_coordinates_path=BS_LS_coordinates_path,
)

### build the junction and flanking intron dictionaries
bs_ls_dataset.get_junction_flanking_intron_seq()

### the same training keys are reused for the RCM and A2I mlp structure selection
train_keys, test_keys = bs_ls_dataset.get_train_test_keys()

In [18]:
# sanity check: total number of keys across the train and test splits
sum(len(keys) for keys in (train_keys, test_keys))

25942

In [19]:
# Peek at the training keys; judging by the values displayed below each key
# looks like "chrom|start|end|strand" -- TODO confirm against get_train_test_keys.
train_keys

array(['chr15|55669146|55759359|-', 'chr5|130977406|131006324|-',
       'chr16|88037900|88098938|+', ..., 'chr4|41526425|41668683|+',
       'chr4|75066979|75147267|+', 'chr2|56419611|56599613|+'],
      dtype='<U27')

In [10]:
# Encode the training keys as tensors: the concatenated upper/lower junction
# sequence, the three RCM score tracks and the labels (in the order returned
# by seq_to_tensor when is_rcm and is_upper_lower_concat are both True).
(train_torch_upper_lower_features,
 train_torch_rcm_flanking_features,
 train_torch_rcm_upper_features,
 train_torch_rcm_lower_features,
 train_torch_labels) = bs_ls_dataset.seq_to_tensor(
    data_keys=train_keys,
    rcm_folder=rcm_scores_folder,
    is_rcm=True,
    is_upper_lower_concat=True,
)

# Wrap the tensors in the dataset class that also serves the RCM features.
BS_LS_dataset = BS_LS_upper_lower_concat_rcm(
    include_rcm=True,
    seq_upper_lower_feature=train_torch_upper_lower_features,
    flanking_rcm=train_torch_rcm_flanking_features,
    upper_rcm=train_torch_rcm_upper_features,
    lower_rcm=train_torch_rcm_lower_features,
    label=train_torch_labels,
)

In [30]:
def combined_model1_selection_optuna(num_trial, train_keys):
    """Run an Optuna study (3-fold CV) for the combined sequence + RCM model.

    Uses the notebook-level ``bs_ls_dataset`` and ``rcm_scores_folder`` to
    build the training tensors, optimises ``Objective_CV`` for ``num_trial``
    trials (maximising the mean validation accuracy) and writes a text
    summary (``optuna.txt``, appended) plus the trial table (``optuna.csv``,
    tab separated) to the optuna folder.

    Args:
        num_trial: number of Optuna trials to run.
        train_keys: keys of the training examples passed to ``seq_to_tensor``.
    """
    ### where to save the 3-fold CV validation acc
    val_acc_folder = '/home/wangc90/circRNA/circRNA_Data/model_outputs/combined_model1/val_acc_cv3'
    ### where to save the best model in the 3-fold CV
    model_folder = '/home/wangc90/circRNA/circRNA_Data/model_outputs/combined_model1/models'
    ### where to save the detailed optuna results
    optuna_folder = '/home/wangc90/circRNA/circRNA_Data/model_outputs/combined_model1/optuna'

    # Create every output folder up front: the summary is only written after
    # the whole study has finished, so a missing folder would otherwise
    # discard hours of training.
    for folder in (val_acc_folder, model_folder, optuna_folder):
        os.makedirs(folder, exist_ok=True)

    ### sequence features together with the three RCM score tracks
    train_torch_upper_lower_features,\
    train_torch_rcm_flanking_features, train_torch_rcm_upper_features,\
    train_torch_rcm_lower_features, train_torch_labels = bs_ls_dataset.seq_to_tensor(data_keys=train_keys, rcm_folder=rcm_scores_folder, is_rcm=True, is_upper_lower_concat=True)

    BS_LS_dataset = BS_LS_upper_lower_concat_rcm(include_rcm=True,
                                                 seq_upper_lower_feature=train_torch_upper_lower_features,
                                                 flanking_rcm=train_torch_rcm_flanking_features,
                                                 upper_rcm=train_torch_rcm_upper_features,
                                                 lower_rcm=train_torch_rcm_lower_features,
                                                 label=train_torch_labels)

    # NOTE(review): a pruner only acts when the objective calls trial.report()
    # and trial.should_prune(); no such call is visible in this notebook's
    # objective, so no trial is expected to be pruned -- confirm.
    study = optuna.create_study(pruner=optuna.pruners.MedianPruner(n_warmup_steps=2),
                                direction='maximize')

    study.optimize(Objective_CV(patience=10, cv=3, model= ConcatModel1_optuna,
                            dataset=BS_LS_dataset,
                            cnn_num=2, is_simple_cnn=True,
                            is_bicnn=False, include_rcm_tri_cnn=True,
                            val_acc_folder=val_acc_folder,
                            model_folder=model_folder), n_trials=num_trial, gc_after_trial=True)

    pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
    complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
    with open(optuna_folder+'/optuna.txt', 'a') as f:
        f.write("Study statistics: \n")
        f.write(f"Number of finished trials: {len(study.trials)}\n")
        f.write(f"Number of pruned trials: {len(pruned_trials)}\n")
        f.write(f"Number of complete trials: {len(complete_trials)}\n")

        f.write("Best trial:\n")
        if complete_trials:
            trial = study.best_trial
            f.write(f"Value: {trial.value}\n")
            f.write("Params:\n")
            for key, value in trial.params.items():
                f.write(f"{key}:{value}\n")
        else:
            # study.best_trial raises when no trial completed
            f.write("none (no trial completed)\n")

    df = study.trials_dataframe().drop(['state','datetime_start','datetime_complete','duration','number'], axis=1)
    df.to_csv(optuna_folder + '/optuna.csv', sep='\t', index=None)

In [31]:
# launch the hyper-parameter search with a single Optuna trial
combined_model1_selection_optuna(num_trial=1, train_keys=train_keys)

[32m[I 2023-04-24 20:10:04,177][0m A new study created in memory with name: no-name-234debaf-f64e-455f-a890-00c8eafa1bda[0m


fold 1, epoch 20, training loss 13.18324564397335, train accuracy 0.9999
fold 1, epoch 20, val accuracy 0.8206
fold 1, epoch 40, training loss 8.222176551818848, train accuracy 1.0
fold 1, epoch 40, val accuracy 0.8292
fold 1, epoch 60, training loss 5.735739767551422, train accuracy 1.0
fold 1, epoch 60, val accuracy 0.8383
fold 1, epoch 80, training loss 4.022277180105448, train accuracy 1.0
fold 1, epoch 80, val accuracy 0.8395
fold 1, epoch 100, training loss 2.8500091172754765, train accuracy 1.0
fold 1, epoch 100, val accuracy 0.8377
fold 2, epoch 20, training loss 7.838870480656624, train accuracy 1.0
fold 2, epoch 20, val accuracy 0.8309
fold 2, epoch 40, training loss 4.542447112500668, train accuracy 1.0
fold 2, epoch 40, val accuracy 0.8346
fold 2, epoch 60, training loss 3.0808320567011833, train accuracy 1.0
fold 2, epoch 60, val accuracy 0.8364
fold 3, epoch 20, training loss 10.071493342518806, train accuracy 0.9997
fold 3, epoch 20, val accuracy 0.8361
fold 3, epoch 40,

[32m[I 2023-04-24 21:20:49,511][0m Trial 0 finished with value: 0.839 and parameters: {'lr': 1.0854792676433357e-05, 'l2_lambda': 3.944333673705097e-05, 'batch_size': 256, 'concat_fc1_out': 512, 'concat_dropout_rate_fc1': 0.2, 'concat_fc2_out': 32, 'concat_dropout_rate_fc2': 0.1}. Best is trial 0 with value: 0.839.[0m
