# Process data for Fibrosis



def data_splitting(data_df, type_of_split,
                   random_state=42,
                   butina_cutoff=0.5,
                   Hi_cutoff=0.5, hi_threshold=0.4):
    """Split a dataset into train and test parts using a named strategy.

    Parameters
    ----------
    data_df :
        Table with a "SMILES" column. The "random_stratified" strategy
        additionally reads an "Activity" column to stratify on.
    type_of_split :
        One of "random", "random_stratified", "butina" or "LoHi_Hi_split".
    random_state :
        Seed forwarded to whichever splitter is selected (default 42).
    butina_cutoff :
        ``cutoff`` argument forwarded to ``butina_split`` (default 0.5).
    Hi_cutoff :
        ``cutoff`` argument forwarded to ``get_hi_split`` (default 0.5).
    hi_threshold :
        ``threshold`` argument forwarded to ``get_hi_split`` (default 0.4).

    Returns
    -------
    tuple
        ``(train, test)``. For "random", "random_stratified" and
        "LoHi_Hi_split" each element is the list of index labels taken from
        the objects the splitter returned. For "butina" the two values
        returned by ``butina_split`` are passed through unchanged
        (NOTE(review): their exact type depends on ``butina_split`` --
        confirm that callers handle it the same way as the index lists).

    Raises
    ------
    ValueError
        If ``type_of_split`` is not one of the supported strategy names.
    """
    supported = ("random", "random_stratified", "butina", "LoHi_Hi_split")
    # Fail loudly on a typo instead of implicitly returning None, which would
    # only surface later as an unpacking error at the call site.
    if type_of_split not in supported:
        raise ValueError(
            f"Unknown type_of_split {type_of_split!r}; "
            f"expected one of {supported}"
        )

    if type_of_split == "butina":
        train_part, test_part = butina_split(data_df["SMILES"],
                                             cutoff=butina_cutoff,
                                             seed=random_state)
        return (train_part, test_part)

    if type_of_split == "LoHi_Hi_split":
        # get_hi_split is called with a lower-case "smiles" column.
        hi_split_df = data_df[["SMILES"]].rename(columns={"SMILES": "smiles"})
        # Honour the caller's seed; it was previously pinned to 42 here,
        # unlike every other strategy in this function.
        train_part, test_part = get_hi_split(hi_split_df,
                                             threshold=hi_threshold,
                                             cutoff=Hi_cutoff,
                                             seed=random_state)
    else:
        # "random" and "random_stratified" differ only in the stratify
        # argument; None requests a plain (unstratified) shuffle split.
        if type_of_split == "random_stratified":
            strata = data_df["Activity"]
        else:
            strata = None
        train_part, test_part = train_test_split(data_df["SMILES"],
                                                 test_size=0.2,
                                                 stratify=strata,
                                                 random_state=random_state)

    return (train_part.index.to_list(), test_part.index.to_list())

In [None]:
DATA_FILE_PATH = '