In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, random_split
import torch.optim as optim
import pandas as pd
import numpy as np
import os
from collections import defaultdict, Counter
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import KFold
import optuna
from torchmetrics.classification import F1Score
import pickle
import sys
import warnings
warnings.filterwarnings("ignore")
import random
from sklearn.model_selection import ShuffleSplit
from sklearn.cluster import KMeans,DBSCAN,Birch
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score, v_measure_score

sys.path.insert(1, '/home/wangc90/Data_integration/MOCSS/mocss/code/')
from critic import LinearCritic
# from Supcon import SupConLoss
from contrastive_loss import InstanceLoss, ClusterLoss
import evaluation
from sklearn import metrics
from Data_prep import DataSet_Prep, DataSet_construction
random.seed(2023)
torch.manual_seed(2023)


<torch._C.Generator at 0x7fe5cc06ac50>

In [2]:
class SSO_Encoder(nn.Module):
    """
        takes in 3 omic data type measurements for the same set of subjects
    """
    
    def __init__(self, trial):
        
        
        possible_nodes = [32, 64, 128, 256, 512, 1024]
        possible_dropout = [0, 0.1, 0.2, 0.4, 0.6]
        
        ### input dimension for omic 1, omic 2 and omic 3
        self.s1_input_dim = 20531
        self.s2_input_dim = 1046
        
        self.l1_s1_out_dim = trial.suggest_categorical("l1_s1_out_dim", possible_nodes)
        self.l1_s2_out_dim = trial.suggest_categorical("l1_s2_out_dim", possible_nodes)
        self.l1_s12_out_dim = trial.suggest_categorical("l1_s12_out_dim", possible_nodes)
        
        
        self.l2_s1_out_dim = trial.suggest_categorical("l2_s1_out_dim", possible_nodes)
        self.l2_s2_out_dim = trial.suggest_categorical("l2_s2_out_dim", possible_nodes)
        self.l2_s12_out_dim = trial.suggest_categorical("l2_s12_out_dim", possible_nodes)

        
        self.l3_s1_out_dim = trial.suggest_categorical("l3_s1_out_dim", possible_nodes)
        self.l3_s2_out_dim = trial.suggest_categorical("l3_s2_out_dim", possible_nodes)
        self.l3_s12_out_dim = trial.suggest_categorical("l3_s12_out_dim", possible_nodes)
        
        ### embedding for z1, z2 and z12 have to have the same dimension for the 
        ### orthogonal losss based on MOCSS to work
        self.embed_s1_out_dim = trial.suggest_categorical("embed_s1_out_dim", possible_nodes)
        self.embed_s2_out_dim = self.embed_s1_out_dim
        self.embed_s12_out_dim = self.embed_s1_out_dim
        
        
        
        super(SSO_Encoder, self).__init__()
        
        ### encoder structure: 
        
        ######################################################################################
        self.l1_s1 = nn.Linear(self.s1_input_dim, self.l1_s1_out_dim)
        self.l1_s1_bn = nn.BatchNorm1d(self.l1_s1_out_dim)
        l1_s1_drop_rate = trial.suggest_categorical("l1_s1_drop_rate", possible_dropout)
        self.drop_l1_s1 = nn.Dropout(p=l1_s1_drop_rate)
        
        
        self.l2_s1 = nn.Linear(self.l1_s1_out_dim,  self.l2_s1_out_dim)
        self.l2_s1_bn = nn.BatchNorm1d(self.l2_s1_out_dim)
        l2_s1_drop_rate = trial.suggest_categorical("l2_s1_drop_rate", possible_dropout)
        self.drop_l2_s1 = nn.Dropout(p=l2_s1_drop_rate)
        
        
        self.l3_s1 = nn.Linear(self.l2_s1_out_dim, self.l3_s1_out_dim)
        self.l3_s1_bn = nn.BatchNorm1d(self.l3_s1_out_dim)
        l3_s1_drop_rate = trial.suggest_categorical("l3_s1_drop_rate", possible_dropout)
        self.drop_l3_s1 = nn.Dropout(p=l3_s1_drop_rate)
        
        
        self.embed_s1 = nn.Linear(self.l3_s1_out_dim, self.embed_s1_out_dim)
        self.embed_s1_bn = nn.BatchNorm1d(self.embed_s1_out_dim)
        embed_s1_drop_rate = trial.suggest_categorical("embed_s1_drop_rate", possible_dropout)
        self.drop_embed_s1 = nn.Dropout(p=embed_s1_drop_rate)
        
        
        ###########################################################################################
        self.l1_s2 = nn.Linear(self.s2_input_dim, self.l1_s2_out_dim)
        self.l1_s2_bn = nn.BatchNorm1d(self.l1_s2_out_dim)
        l1_s2_drop_rate = trial.suggest_categorical("l1_s2_drop_rate", possible_dropout)
        self.drop_l1_s2 = nn.Dropout(p=l1_s2_drop_rate)
        
        
        self.l2_s2 = nn.Linear(self.l1_s2_out_dim,  self.l2_s2_out_dim)
        self.l2_s2_bn = nn.BatchNorm1d(self.l2_s2_out_dim)
        l2_s2_drop_rate = trial.suggest_categorical("l2_s2_drop_rate", possible_dropout)
        self.drop_l2_s2 = nn.Dropout(p=l2_s2_drop_rate)
        
        
        self.l3_s2 = nn.Linear(self.l2_s2_out_dim, self.l3_s2_out_dim)
        self.l3_s2_bn = nn.BatchNorm1d(self.l3_s2_out_dim)
        l3_s2_drop_rate = trial.suggest_categorical("l3_s2_drop_rate", possible_dropout)
        self.drop_l3_s2 = nn.Dropout(p=l3_s2_drop_rate)
        
        
        self.embed_s2 = nn.Linear(self.l3_s2_out_dim, self.embed_s2_out_dim)
        self.embed_s2_bn = nn.BatchNorm1d(self.embed_s2_out_dim)
        embed_s2_drop_rate = trial.suggest_categorical("embed_s2_drop_rate", possible_dropout)
        self.drop_embed_s2 = nn.Dropout(p=embed_s2_drop_rate)
        
        
        
        ##########################################################################################
        
        self.l1_s12 = nn.Linear(self.s1_input_dim + self.s2_input_dim,
                                 self.l1_s12_out_dim)
        self.l1_s12_bn = nn.BatchNorm1d(self.l1_s12_out_dim)
        l1_s12_drop_rate = trial.suggest_categorical("l1_s12_drop_rate", possible_dropout)
        self.drop_l1_s12 = nn.Dropout(p=l1_s12_drop_rate)
        
        
        self.l2_s12 = nn.Linear(self.l1_s12_out_dim,  self.l2_s12_out_dim)
        self.l2_s12_bn = nn.BatchNorm1d(self.l2_s12_out_dim)
        l2_s12_drop_rate = trial.suggest_categorical("l2_s12_drop_rate", possible_dropout)
        self.drop_l2_s12 = nn.Dropout(p=l2_s12_drop_rate)
        
        
        self.l3_s12 = nn.Linear(self.l2_s12_out_dim, self.l3_s12_out_dim)
        self.l3_s12_bn = nn.BatchNorm1d(self.l3_s12_out_dim)
        l3_s12_drop_rate = trial.suggest_categorical("l3_s12_drop_rate", possible_dropout)
        self.drop_l3_s12 = nn.Dropout(p=l3_s12_drop_rate)
        
        
        self.embed_s12 = nn.Linear(self.l3_s12_out_dim, self.embed_s12_out_dim)
        self.embed_s12_bn = nn.BatchNorm1d(self.embed_s12_out_dim)
        embed_s12_drop_rate = trial.suggest_categorical("embed_s12_drop_rate", possible_dropout)
        self.drop_embed_s12 = nn.Dropout(p=embed_s12_drop_rate)
        

    def forward(self, s1, s2, labels=None):
        
        #############################################################
        s1_ = self.drop_l1_s1(self.l1_s1_bn(F.relu(self.l1_s1(s1))))
        s1_ = self.drop_l2_s1(self.l2_s1_bn(F.relu(self.l2_s1(s1_))))
        s1_ = self.drop_l3_s1(self.l3_s1_bn(F.relu(self.l3_s1(s1_))))
        z1 = self.drop_embed_s1(self.embed_s1_bn(F.relu(self.embed_s1(s1_))))
        
        
        
        s2_ = self.drop_l1_s2(self.l1_s2_bn(F.relu(self.l1_s2(s2))))
        s2_ = self.drop_l2_s2(self.l2_s2_bn(F.relu(self.l2_s2(s2_))))
        s2_ = self.drop_l3_s2(self.l3_s2_bn(F.relu(self.l3_s2(s2_))))
        z2 = self.drop_embed_s2(self.embed_s2_bn(F.relu(self.embed_s2(s2_))))
        
        
        
        
        ### concatenate s1, s2 together for the joint embedding
        s12 = torch.cat((s1, s2), dim=1)
        s12_ = self.drop_l1_s12(self.l1_s12_bn(F.relu(self.l1_s12(s12))))
        s12_ = self.drop_l2_s12(self.l2_s12_bn(F.relu(self.l2_s12(s12_))))
        s12_ = self.drop_l3_s12(self.l3_s12_bn(F.relu(self.l3_s12(s12_))))
        z12 = self.drop_embed_s12(self.embed_s12_bn(F.relu(self.embed_s12(s12_))))
        
        
        return z1, z2, z12, labels
    
    

class SSO_Decoder(nn.Module):
    
    ### decoder: construct s1 and s2  based on the concatenated z12 z1 and z2
    ### and calculate the reconstruction loss separately for s1 and s2
    
    def __init__(self, trial):
        
        
        possible_nodes = [32, 64, 128, 256, 512, 1024]
        possible_dropout = [0, 0.1, 0.2, 0.4, 0.6]
        
        self.s1_input_dim = SSO_Encoder(trial).s1_input_dim
        self.s2_input_dim = SSO_Encoder(trial).s2_input_dim
        
        self.s1_embed_dim = SSO_Encoder(trial).embed_s1_out_dim
        self.s2_embed_dim = SSO_Encoder(trial).embed_s2_out_dim
        self.s12_embed_dim = SSO_Encoder(trial).embed_s12_out_dim
        
        
        self._embed_s1_out_dim = trial.suggest_categorical("_embed_s1_out_dim", possible_nodes)
        self._l3_s1_out_dim = trial.suggest_categorical("_l3_s1_out_dim", possible_nodes)
        self._l2_s1_out_dim = trial.suggest_categorical("_l2_s1_out_dim", possible_nodes)
        self._l1_s1_out_dim = self.s1_input_dim
        
        
        self._embed_s2_out_dim = trial.suggest_categorical("_embed_s2_out_dim", possible_nodes)
        self._l3_s2_out_dim = trial.suggest_categorical("_l3_s2_out_dim", possible_nodes)
        self._l2_s2_out_dim = trial.suggest_categorical("_l2_s2_out_dim", possible_nodes)
        self._l1_s2_out_dim = self.s2_input_dim
        
        
        super(SSO_Decoder, self).__init__()
        
        
        self._embed_s1 = nn.Linear(self.s1_embed_dim + self.s2_embed_dim + self.s12_embed_dim,\
                                   self._embed_s1_out_dim)
        
        self._embed_s1_bn = nn.BatchNorm1d(self._embed_s1_out_dim)
        _embed_s1_drop_rate = trial.suggest_categorical("_embed_s1_drop_rate", possible_dropout)
        self._drop_embed_s1 = nn.Dropout(p=_embed_s1_drop_rate)
        
        
        self._l3_s1 = nn.Linear(self._embed_s1_out_dim, self._l3_s1_out_dim)
        self._l3_s1_bn = nn.BatchNorm1d(self._l3_s1_out_dim)
        _l3_s1_drop_rate = trial.suggest_categorical("_l3_s1_drop_rate", possible_dropout)
        self._drop_l3_s1 = nn.Dropout(p=_l3_s1_drop_rate)
        
        
        self._l2_s1 = nn.Linear(self._l3_s1_out_dim, self._l2_s1_out_dim)
        self._l2_s1_bn = nn.BatchNorm1d(self._l2_s1_out_dim)
        _l2_s1_drop_rate = trial.suggest_categorical("_l2_s1_drop_rate", possible_dropout)
        self._drop_l2_s1 = nn.Dropout(p=_l2_s1_drop_rate)
        
        
        self._l1_s1 = nn.Linear(self._l2_s1_out_dim, self._l1_s1_out_dim)
        self._l1_s1_bn = nn.BatchNorm1d(self._l1_s1_out_dim)
        _l1_s1_drop_rate = trial.suggest_categorical("_l1_s1_drop_rate", possible_dropout)
        self._drop_l1_s1 = nn.Dropout(p=_l1_s1_drop_rate)
        
        #############################################################################

        self._embed_s2 = nn.Linear(self.s1_embed_dim + self.s2_embed_dim + self.s12_embed_dim,\
                                   self._embed_s2_out_dim)
        
        self._embed_s2_bn = nn.BatchNorm1d(self._embed_s2_out_dim)
        _embed_s2_drop_rate = trial.suggest_categorical("_embed_s2_drop_rate", possible_dropout)
        self._drop_embed_s2 = nn.Dropout(p=_embed_s2_drop_rate)
        
        
        self._l3_s2 = nn.Linear(self._embed_s2_out_dim, self._l3_s2_out_dim)
        self._l3_s2_bn = nn.BatchNorm1d(self._l3_s2_out_dim)
        _l3_s2_drop_rate = trial.suggest_categorical("_l3_s2_drop_rate", possible_dropout)
        self._drop_l3_s2 = nn.Dropout(p=_l3_s2_drop_rate)
        
        
        self._l2_s2 = nn.Linear(self._l3_s2_out_dim, self._l2_s2_out_dim)
        self._l2_s2_bn = nn.BatchNorm1d(self._l2_s2_out_dim)
        _l2_s2_drop_rate = trial.suggest_categorical("_l2_s2_drop_rate", possible_dropout)
        self._drop_l2_s2 = nn.Dropout(p=_l2_s2_drop_rate)
        
        
        self._l1_s2 = nn.Linear(self._l2_s2_out_dim, self._l1_s2_out_dim)
        self._l1_s2_bn = nn.BatchNorm1d(self._l1_s2_out_dim)
        _l1_s2_drop_rate = trial.suggest_categorical("_l1_s2_drop_rate", possible_dropout)
        self._drop_l1_s2 = nn.Dropout(p=_l1_s2_drop_rate)
        
        
        
    def forward(self, z1, z2, z12):
        
        z_all = torch.cat((z1, z2, z12), dim=1)
        
        
        s1_ = self._drop_embed_s1(self._embed_s1_bn(F.relu(self._embed_s1(z_all))))
        s1_ = self._drop_l3_s1(self._l3_s1_bn(F.relu(self._l3_s1(s1_))))
        s1_ = self._drop_l2_s1(self._l2_s1_bn(F.relu(self._l2_s1(s1_))))
        s1_ = self._drop_l1_s1(self._l1_s1_bn(F.relu(self._l1_s1(s1_))))
        
        s1_out = torch.sigmoid(s1_)
        
        
        
        s2_ = self._drop_embed_s2(self._embed_s2_bn(F.relu(self._embed_s2(z_all))))
        
        s2_ = self._drop_l3_s2(self._l3_s2_bn(F.relu(self._l3_s2(s2_))))
        s2_ = self._drop_l2_s2(self._l2_s2_bn(F.relu(self._l2_s2(s2_))))
        s2_ = self._drop_l1_s2(self._l1_s2_bn(F.relu(self._l1_s2(s2_))))
        
        s2_out = torch.sigmoid(s2_)
        
        return s1_out, s2_out

    
class SSO_AE(nn.Module):
    def __init__(self, trial):
        
        super(SSO_AE, self).__init__()
        
        self.encoder = SSO_Encoder(trial)
        self.decoder = SSO_Decoder(trial)

    def forward(self, s1, s2, labels):
        
        ### encoder ouput for embeddings
        z1, z2, z12, labels = self.encoder(s1, s2, labels)
        
        ### decoder output for reconstructed input
        s1_out, s2_out = self.decoder(z1, z2, z12)
        
        return z1, z2, z12, s1_out, s2_out, labels

In [3]:
# SSO_AE()

In [4]:
def CustomLoss(s1, s2, s1_out, s2_out,
                z1, z2, z12, labels):

    """
        Ortho_AE1
    """

    ### normalize the feature vector with length 1
    s1_out = F.normalize(s1_out, p=2, dim=1)
    s2_out = F.normalize(s2_out, p=2, dim=1)
    
    
    s1 = F.normalize(s1, p=2, dim=1)
    s2 = F.normalize(s2, p=2, dim=1)

    recon_loss = torch.linalg.matrix_norm(s1_out-s1) + torch.linalg.matrix_norm(s2_out-s2)

    z1 = F.normalize(z1, p=2, dim=1)
    z2 = F.normalize(z2, p=2, dim=1)
    z12 = F.normalize(z12, p=2, dim=1)
    
    z12_z1 = torch.mul(z12, z1)
    z12_z2 = torch.mul(z12, z2)
    
    z12_z1_cost = z12_z1.mean()
    z12_z2_cost = z12_z2.mean()
    
    ortho_loss = z12_z1_cost + z12_z2_cost
    
    return recon_loss, ortho_loss

In [5]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

In [6]:
### create a training function that is seamless for all different models
def Objective(device, trial, fold, model, optimizer,
              epochs, train_loader, val_loader, ortho_multiplier):

    
    #### model training on the training set
    
    for epoch in range(epochs):
        
        model.train()
        # record the training loss
        total_recon_loss = 0.0
        total_ortho_loss = 0.0
        total_train = 0.0
        
        
        for iteration_index, train_batch in enumerate(train_loader):
            view1_train_data, view2_train_data, train_labels = train_batch

            view1_train_data = view1_train_data.type(torch.float32).to(device)
            view2_train_data = view2_train_data.type(torch.float32).to(device)
            train_labels = train_labels.type(torch.LongTensor).to(device)
            
            z1, z2, z12, s1_out, s2_out, labels = \
                model(view1_train_data, view2_train_data, train_labels)
            
#             print(z123.shape, s1_out.shape, s2_out.shape, s3_out.shape)
            train_size = z12.size()[0]
    
            recon_loss, ortho_loss = CustomLoss(s1=view1_train_data,\
                                             s2=view2_train_data,\
                                             s1_out=s1_out,\
                                             s2_out=s2_out,\
                                             z1=z1,\
                                             z2=z2,\
                                             z12=z12,\
                                             labels=train_labels)
            
#             print(clustering_loss)
            
            loss = recon_loss + (ortho_loss * ortho_multiplier)
            # backward pass
            optimizer.zero_grad()  # empty the gradient from last round

            # calculate the gradient
            loss.backward()
            
            
#             print(model.weight.grad)
            # update the parameters
            optimizer.step()
            
            total_train += train_size
#             total_clustering_loss += (clustering_loss.item()*train_size)
            total_recon_loss += recon_loss.item()
            total_ortho_loss += ortho_loss.item()

        
        if (epoch + 1) % 10 == 0: 
            print(f'fold {fold+1} epoch {epoch+1}')
            print(f'average train recon loss is: {total_recon_loss/total_train}')
            print(f'average train ortho loss is: {total_ortho_loss/total_train}')
            
            
    #### model evaluation on the validation set at the last epoch and return acc
    
    model.eval()
    with torch.no_grad():

        total_recon_loss = 0.0
        total_val = 0.0

        for iteration_index, val_batch in enumerate(val_loader):
            view1_val_data, view2_val_data, val_labels = val_batch

            view1_val_data = view1_val_data.type(torch.float32).to(device)
            view2_val_data = view2_val_data.type(torch.float32).to(device)
            val_labels = val_labels.type(torch.LongTensor).to(device)

            z1_val, z2_val, z12_val, s1_out_val, s2_out_val, labels_val = \
                model(view1_val_data, view2_val_data, val_labels)

            val_size = z12.size()[0]

            recon_loss, ortho_loss = CustomLoss(s1=view1_val_data,\
                                         s2=view2_val_data,\
                                         s1_out=s1_out_val,\
                                         s2_out=s2_out_val,\
                                         z1=z1_val,\
                                         z2=z2_val,\
                                         z12=z12_val,\
                                         labels=labels_val)

            total_val += val_size
            total_recon_loss += recon_loss.item()
        avg_recon_loss = total_recon_loss / total_val
    return avg_recon_loss



class Objective_CV:

    def __init__(self, cv, model, dataset, val_loss_folder):

        self.cv = cv  ## number of CV
        self.model = model  ## pass the corresponding model
        self.dataset = dataset  ## the corresponding dataset object
        self.val_loss_folder = val_loss_folder  ## folder to store the cross_validation accuracy

    def __call__(self, trial):

        ### just use the sequence feature for now
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
#         lr = 0.001
        l2_lambda = trial.suggest_float("l2_lambda", 1e-8, 1e-5, log=True)
#         l2_lambda = 0
        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128, 256, 512])
    
        ### optimize the multiplier for orthogonal loss
        ortho_multiplier = trial.suggest_categorical("ortho_multiplier", [1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3])

        ### optimize epoch number
        epochs = trial.suggest_categorical("epoch", [30, 60, 90, 120, 150])
        
        ## choose the optimizer
#         optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop","Adagrad"])
        optimizer_name = 'Adam'

        kfold = KFold(n_splits=self.cv, shuffle=True)

        setup_seed(21)
        
        val_fold_loss = []
                    
        for fold, (train_index, val_index) in enumerate(kfold.split(np.arange(len(self.dataset)))):

            ### get the train and val loader
            train_subset = torch.utils.data.Subset(self.dataset, train_index)
            train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
            

            val_subset = torch.utils.data.Subset(self.dataset, val_index)
            val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=True)
            

            ## model should be initilized here for each fold to have a new model with same hyperparameters

            ### for the model the process the concatenated upper and lower the is_rcm is always False
            Model = self.model(trial).to(device=device)
#             print(Model)
            optimizer = getattr(optim, optimizer_name)(Model.parameters(), lr=lr, weight_decay=l2_lambda)

            val_loss = Objective(device, trial, fold=fold, model=Model, optimizer=optimizer,
                                 epochs=epochs, train_loader=train_loader, val_loader=val_loader,
                                    ortho_multiplier=ortho_multiplier)

            val_fold_loss.append(val_loss)
            
        avg_val_loss = np.mean(val_fold_loss)

        val_loss_path = f"{self.val_loss_folder}/val_loss.csv"

        val_loss_str = '\t'.join([str(i) for i in val_fold_loss])
        with open(val_loss_path, 'a') as f:
            f.write('trial' + str(trial.number) + '\t' + val_loss_str + '\n')

        return avg_val_loss


In [7]:
def Ortho_AE_model_selection(num_trial):

    ### where to save the 3-fold CV validation acc

    val_loss_folder = '/home/wangc90/Data_integration/TCGA_model_outputs/model_selection_outputs/Ortho_AE/val_loss'
    
    ### wehre to save the detailed optuna results
    optuna_folder = '/home/wangc90/Data_integration/TCGA_model_outputs/model_selection_outputs/Ortho_AE/optuna'
    
    
    combined_exp_df = pd.read_csv('/home/wangc90/Data_integration/TCGA_data/TCGA_primary_tumor_data/combined_exp_df.csv', sep='\t')
    combined_miRNA_df = pd.read_csv('/home/wangc90/Data_integration/TCGA_data/TCGA_primary_tumor_data/combined_miRNA_df.csv', sep='\t')

    labels = pd.read_csv('/home/wangc90/Data_integration/TCGA_data/TCGA_primary_tumor_data/labels.csv', sep='\t')['0']

    dataset_prep = DataSet_Prep(data1=combined_exp_df, data2=combined_miRNA_df, label=labels, training_prop=0.8)

    train_key, test_key = dataset_prep.get_train_test_keys()

    feature1_tensors, feature2_tensors, label_tensors = dataset_prep.to_tensor(train_key)
    
    train_dataset = DataSet_construction(feature1_tensors, feature2_tensors, label_tensors)

    print(len(train_dataset))
    
    
    study = optuna.create_study(pruner=optuna.pruners.MedianPruner(n_warmup_steps=2),
                                direction='minimize')
    
    
    study.optimize(Objective_CV(cv=5, model= SSO_AE,
                                dataset=train_dataset,
                                val_loss_folder=val_loss_folder),
                   n_trials=num_trial, gc_after_trial=True)


    pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
    complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
    with open(optuna_folder+'/optuna.txt', 'a') as f:
        f.write("Study statistics: \n")
        f.write(f"Number of finished trials: {len(study.trials)}\n")
        f.write(f"Number of pruned trials: {len(pruned_trials)}\n")
        f.write(f"Number of complete trials: {len(complete_trials)}\n")

        f.write("Best trial:\n")
        trial = study.best_trial
        f.write(f"Value: {trial.value}\n")
        f.write("Params:\n")
        for key, value in trial.params.items():
            f.write(f"{key}:{value}\n")

    df = study.trials_dataframe().drop(['state','datetime_start','datetime_complete','duration','number'], axis=1)
    df.to_csv(optuna_folder + '/optuna.csv', sep='\t', index=None)

In [None]:
import time

# Record the start time
start_time = time.time()


Ortho_AE_model_selection(num_trial=100)

# Record the end time
end_time = time.time()

# Calculate the elapsed time
elapsed_time = end_time - start_time

print(f"Elapsed time: {elapsed_time/60} minutes")

[32m[I 2024-02-22 22:52:57,552][0m A new study created in memory with name: no-name-3496bcf5-a362-40f9-89c6-6113f2f2180a[0m


feature1 and feature2 are being scaled with MinMaxScaler
1494
fold 1 epoch 10
average train recon loss is: 0.09375804757473359
average train ortho loss is: 1.0580542910392066e-06
fold 1 epoch 20
average train recon loss is: 0.09188606549506408
average train ortho loss is: 8.574921479955253e-07
fold 1 epoch 30
average train recon loss is: 0.09054481474425503
average train ortho loss is: 1.0027854594257836e-06
fold 2 epoch 10
average train recon loss is: 0.0934993089492351
average train ortho loss is: 1.0682194674909737e-06
fold 2 epoch 20
average train recon loss is: 0.09165584292870685
average train ortho loss is: 1.1690315317019634e-06
fold 2 epoch 30
average train recon loss is: 0.0904821004827651
average train ortho loss is: 1.0176358654876883e-06
fold 3 epoch 10
average train recon loss is: 0.09368992889276609
average train ortho loss is: 8.588714503671459e-07
fold 3 epoch 20
average train recon loss is: 0.09199450185608166
average train ortho loss is: 6.863567539759635e-07
fold 3 

[32m[I 2024-02-22 22:57:58,221][0m Trial 0 finished with value: 0.18130907753641293 and parameters: {'lr': 0.002092733733374216, 'l2_lambda': 6.756429983729922e-07, 'batch_size': 512, 'ortho_multiplier': 0.001, 'epoch': 30, 'l1_s1_out_dim': 128, 'l1_s2_out_dim': 128, 'l1_s12_out_dim': 128, 'l2_s1_out_dim': 128, 'l2_s2_out_dim': 1024, 'l2_s12_out_dim': 1024, 'l3_s1_out_dim': 1024, 'l3_s2_out_dim': 256, 'l3_s12_out_dim': 512, 'embed_s1_out_dim': 256, 'l1_s1_drop_rate': 0, 'l2_s1_drop_rate': 0, 'l3_s1_drop_rate': 0.6, 'embed_s1_drop_rate': 0.6, 'l1_s2_drop_rate': 0.2, 'l2_s2_drop_rate': 0.6, 'l3_s2_drop_rate': 0.1, 'embed_s2_drop_rate': 0.2, 'l1_s12_drop_rate': 0, 'l2_s12_drop_rate': 0, 'l3_s12_drop_rate': 0, 'embed_s12_drop_rate': 0.4, '_embed_s1_out_dim': 1024, '_l3_s1_out_dim': 256, '_l2_s1_out_dim': 512, '_embed_s2_out_dim': 128, '_l3_s2_out_dim': 1024, '_l2_s2_out_dim': 64, '_embed_s1_drop_rate': 0.6, '_l3_s1_drop_rate': 0.4, '_l2_s1_drop_rate': 0, '_l1_s1_drop_rate': 0.4, '_embed_

fold 1 epoch 10
average train recon loss is: 0.18690629743631915
average train ortho loss is: 2.3914811410356673e-07
fold 1 epoch 20
average train recon loss is: 0.18637019520524156
average train ortho loss is: 2.8207340692004354e-07
fold 1 epoch 30
average train recon loss is: 0.18553696317153995
average train ortho loss is: 3.460403353391823e-07
fold 1 epoch 40
average train recon loss is: 0.1846133595231188
average train ortho loss is: 4.2536507044513573e-07
fold 1 epoch 50
average train recon loss is: 0.18364768566945608
average train ortho loss is: 4.772583235086382e-07
fold 1 epoch 60
average train recon loss is: 0.18263988734289194
average train ortho loss is: 6.301374326627624e-07
fold 2 epoch 10
average train recon loss is: 0.18682116105466706
average train ortho loss is: 1.9733839949690852e-07
fold 2 epoch 20
average train recon loss is: 0.18626511226638093
average train ortho loss is: 3.100316348060903e-07
fold 2 epoch 30
average train recon loss is: 0.18545851128869476
aver