In [1]:
## Suppress RDKit warnings
def mute_rdkit():
    from rdkit import RDLogger
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.CRITICAL)

In [2]:
def assign_value(idx, list_train, list_val, list_test):
    if idx in list_train:
        return 'Training'
    elif idx in list_val:
        return 'Validation'
    elif idx in list_test:
        return 'Test' 
    else:
        return 'Training'

In [3]:
## ================================================================================================
## ==================================== random split ==============================================
## ================================================================================================
def nFoldSplit_random(dataTable, colName_mid='Compound Name', CV=10, rng=666666, hasVal=True):
    ds_size = dataTable.shape[0]
    assert CV*2 < ds_size, f"\tError, the dataset (N={ds_size}) is too small to do a {CV}_fold split! Please decrease the CV value ({CV})"

    dataTable_split = dataTable[[colName_mid]].reset_index(drop=True)
    list_mol_idx = dataTable_split.index.to_numpy()

    # Shuffle the list using random seed
    import numpy as np
    np.random.seed(rng)
    np.random.shuffle(list_mol_idx)

    # Split the list into N sublists
    sublists = np.array_split(list_mol_idx, CV)
    idx_test = sublists[CV-1]
    idx_val = sublists[CV-2] if hasVal else []
    idx_train = [i for i in list_mol_idx if i not in idx_test and i not in idx_val]
    print(f"\tSplit the data (n={len(list_mol_idx)}) into Train({len(idx_train)}), Val({len(idx_val)}), and Test({len(idx_test)})")

    # Apply the function to assign values to the new column 'A'
    dataTable_split[f'Split'] = dataTable_split.index.to_series().apply(lambda x: assign_value(x, idx_train, idx_val, idx_test))
    return dataTable_split


## ================================================================================================
## ==================================== temporal split ============================================
## ================================================================================================
def nFoldSplit_temporal(dataTable, colName_mid='Compound Name', colName_date="Created On", CV=10, hasVal=True):
    ds_size = dataTable.shape[0]
    assert CV*2 < ds_size, f"\tError, the dataset (N={ds_size}) is too small to do a {CV}_fold split! Please decrease the CV value ({CV})"

    dataTable_split = dataTable[[colName_mid, colName_date]].reset_index(drop=True)
    try:
        import pandas as pd
        dataTable_split[colName_date] = dataTable_split[colName_date].str.split(';').str[0]
        dataTable_split["date_formatted"] = pd.to_datetime(dataTable_split[colName_date])
        dataTable_split = dataTable_split.sort_values(by=["date_formatted"], ascending=[True])
    except Exception as e:
        print(f"\tWarning! The mol date column <{colName_date}> cannot be formatted. Error mgs: {e}")
    else:
        # Split the list into N sublists
        import numpy as np
        list_mol_idx = dataTable_split.index.to_numpy()
        try:
            sublists = np.array_split(list_mol_idx, CV)
        except Exception as e:
            print(f"\tWarning! Cannot split data based on date. Error mgs: {e}")
        else:
            idx_test = sublists[CV-1]
            idx_val = sublists[CV-2] if hasVal else []
            idx_train = [i for i in list_mol_idx if i not in idx_test and i not in idx_val]
            print(f"\tSplit the data (n={len(list_mol_idx)}) into Train({len(idx_train)}), Val({len(idx_val)}), and Test({len(idx_test)})")
            
            # Apply the function to assign values to the new column 'A'
            dataTable_split[f'Split'] = dataTable_split.index.to_series().apply(lambda x: assign_value(x, idx_train, idx_val, idx_test))
    return dataTable_split 


## ================================================================================================
## ==================================== diverse split ============================================
## ================================================================================================
def nFoldSplit_diverse(dataTable, colName_mid='Compound Name', colName_smi="Structure", CV=10, hasVal=True):
    ds_size = dataTable.shape[0]
    assert CV*2 < ds_size, f"\tError, the dataset (N={ds_size}) is too small to do a {CV}_fold split! Please decrease the CV value ({CV})"

    dataTable_split = dataTable[[colName_mid, colName_smi]].sample(frac=1).reset_index(drop=True)
    smiles_list = dataTable_split[colName_smi].to_list()

    ## calc the fps
    mute_rdkit()
    import numpy as np
    from rdkit import Chem, DataStructs, SimDivFilters
    from rdkit.Chem import AllChem
    fps = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), 3, nBits=2048) for smi in smiles_list]

    ## Generate the distance matrix in advance
    ds=[]
    for i in range(1,len(fps)):
        ds.extend(DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i], returnDistance=True))

    ## Initialize the MaxMinPicker
    picker = SimDivFilters.MaxMinPicker()

    ## define the number of mols to pick for test/validation
    num_picks = int(ds_size/CV)
    num_picks_real = 2*num_picks if hasVal else num_picks
    print(num_picks_real, num_picks)

    ## Select N diverse molecules from the set
    pick_idx = picker.Pick(np.array(ds), len(fps), num_picks_real)
    idx_test = pick_idx[:num_picks] if hasVal else pick_idx
    idx_val = pick_idx[num_picks:] if hasVal else []
    idx_train = [i for i in dataTable_split.index if i not in pick_idx]

    # Apply the function to assign values to the new column 'A'
    dataTable_split[f'Split'] = dataTable_split.index.to_series().apply(lambda x: assign_value(x, idx_train, idx_val, idx_test))

    return dataTable_split

In [4]:
def main(split_method):
    '''
    args = Args_Prepation(parser_desc='Preparing the input files and the descriptors')
    fileNameIn = args.input    # '../../1_DataPrep/results/data_input_clean.csv'
    sep = args.delimiter 
    colName_mid = args.colId    # 'Compound Name'
    colName_smi = args.colSmi    # 'Structure'

    '''
    fileNameIn = './results/data_input_clean.csv'
    sep =  ','
    colName_mid = 'Compound Name'
    colName_date = "ADME MDCK(WT) Permeability;Concat;Run Date"  #'Created On'    #
    colName_smi = 'Structure'
    # split_method = 'random'
    CV = 10
    rng = 666666
    hasVal =True  
    
    ## ------------ load data ------------
    import pandas as pd
    dataTable_raw = pd.read_csv(fileNameIn, sep=sep).head(100)
    print(f"\t{dataTable_raw.shape}")
    assert colName_mid in dataTable_raw.columns, f"\tColumn name for mol ID <{colName_mid}> is not in the table."
        

    print(f"\tData split method: {split_method}")
    ## ------------ calculate rdkit properties ------------
    if split_method == 'random':
        dataTable_split = nFoldSplit_random(dataTable_raw, colName_mid, CV=CV, rng=rng, hasVal=hasVal)

    ## ------------ calculate mol fingerprints ------------
    if split_method == 'temporal':
        assert colName_date in dataTable_raw.columns, f"\tColumn name for date <{colName_date}> is not in the table."
        dataTable_split = nFoldSplit_temporal(dataTable_raw, colName_mid, colName_date, CV=CV, hasVal=hasVal)

    ## ------------ calculate chemAxon properties ------------
    if split_method == 'diverse':
        assert colName_smi in dataTable_raw.columns, f"\tColumn name for mol smiles <{colName_smi}> is not in the table."
        dataTable_split = nFoldSplit_diverse(dataTable_raw, colName_mid, colName_smi, CV=10, hasVal=hasVal)


    ## ------------ save the split ------------
    import os
    output_folder = './results'
    os.makedirs(output_folder, exist_ok=True)
    dataTable_split.to_csv(f"{output_folder}/data_split_{split_method}.csv", index=False)


In [5]:
for split_method in ['random', 'temporal', 'diverse']:
    main(split_method=split_method)

	(100, 12)
	Data split method: random
	Split the data (n=100) into Train(80), Val(10), and Test(10)
	(100, 12)
	Data split method: temporal
	Split the data (n=100) into Train(80), Val(10), and Test(10)
	(100, 12)
	Data split method: diverse
20 10


In [None]:
# import pandas as pd
# dataTable = pd.read_csv('./results/data_input_clean.csv')
# dataTable_split = nFoldSplit_diverse(dataTable, colName_mid='Compound Name', colName_smi="Structure", CV=10, hasVal=False)
# dataTable_split['Split'].value_counts()

# import pandas as pd
# dataTable = pd.read_csv('./results/data_input_clean.csv')
# dataTable_split = nFoldSplit_random(dataTable, colName_mid='Compound Name', CV=10, rng=666666, hasVal=False)
# dataTable_split['Split'].value_counts()

# import pandas as pd
# dataTable = pd.read_csv('./results/data_input_clean.csv')
# dataTable_split = nFoldSplit_temporal(dataTable, colName_mid='Compound Name', colName_date='ADME MDCK(WT) Permeability;Concat;Run Date', CV=10, hasVal=False)
# dataTable_split['Split'].value_counts()

In [12]:
!/mnt/data0/Research/0_Test/cx_pKa/bash2py_yjing_local.bash python ./Data_Split.py -i "./results/data_input_clean.csv" -d "," --colId 'Compound Name' --colSmi 'Structure' --colDate "ADME MDCK(WT) Permeability;Concat;Run Date" --split 'diverse' --CV 10 --rng 666666 --hasVal True

	(3830, 12)
	Data split method: diverse
	Split the data (n=3830) into Train(3064), Val(383), and Test(383)


In [None]:
# if __name__ == '__main__':
#     main()