In [5]:
# conda env: datacat (Python 3.8.20)
import os
import csv
import random
import json
from collections import defaultdict
import shutil

import pandas as pd
from datacat4ml.const import CURA_CAT_GPCR_DIR, CURA_LHD_GPCR_DIR, CURA_CAT_50_5k_GPCR_DIR, CURA_CAT_OR_DIR, CURA_LHD_OR_DIR, CURA_CAT_50_5K_OR_DIR
from datacat4ml.const import OR_chemblids, SPLIT_DATA_DIR
from datacat4ml.Scripts.data_prep.data_curate.curate_utils.apply_thresholds import apply_thresholds

# Read MHDs from `CURA_CAT_GPCR_DIR` and `CURA_CAT_OR_DIR`

In [None]:
# Lookup table keyed by the input folder of curated MHDs.
# Each value is a two-item list: index 0 is the folder for the size-filtered MHDs,
# index 1 is the folder for the LHDs derived from them.
cat_lhd_dic = dict([
    (CURA_CAT_GPCR_DIR, [CURA_CAT_50_5k_GPCR_DIR, CURA_LHD_GPCR_DIR]),
    (CURA_CAT_OR_DIR, [CURA_CAT_50_5K_OR_DIR, CURA_LHD_OR_DIR]),
])

def write_stats_file(in_dir: str, task: str, ds_list: list, ds_str: str):
    """
    Append one summary row per dataset in ds_list to '<out_dir>/<task>_stats.csv'.

    The datasets are read from the output folder selected by `ds_str`
    (looked up in `cat_lhd_dic[in_dir]`), joined with `task`.

    params:
    --------
    in_dir: the input directory, e.g. CURA_CAT_GPCR_DIR or CURA_CAT_OR_DIR; must be a key of `cat_lhd_dic`
    task: 'cls' or 'reg'
    ds_list: a list of dataset filenames, e.g. MHDs_50_5k or LHDs
    ds_str: the string to indicate which type of datasets, either 'MHDs_50_5k' or 'LHDs'

    raises:
    --------
    ValueError: if ds_str is neither 'MHDs_50_5k' nor 'LHDs'
    """
    # Position of the output folder inside cat_lhd_dic[in_dir] for each dataset type
    out_dir_index = {'MHDs_50_5k': 0, 'LHDs': 1}
    if ds_str not in out_dir_index:
        raise ValueError(f"ds_str must be 'MHDs_50_5k' or 'LHDs', got {ds_str!r}")
    out_dir = os.path.join(cat_lhd_dic[in_dir][out_dir_index[ds_str]], task)

    # The header is written only when the stats file does not exist yet; an existing
    # file is kept and new rows are appended to it (it is never removed here).
    stats_file = os.path.join(out_dir, f'{task}_stats.csv')
    if not os.path.exists(stats_file):
        os.makedirs(out_dir, exist_ok=True)
        with open(stats_file, 'w') as f:
            f.write('task,target,effect,assay,standard_type,assay_chembl_id,threshold,datasize,num_active,num_inactive,%_active\n')

    # Open the stats file once instead of once per dataset
    with open(stats_file, 'a') as f:
        for ds in ds_list:
            # File names look like <target>_<effect>_<assay>_<standard_type>[_<assay_chembl_id>]_curated.csv
            segs = ds.split('_')
            target = segs[0]
            effect = segs[1]
            assay = segs[2]
            std_type = segs[3]
            # Only LHD file names carry an assay id in the fifth segment
            assay_chembl_id = segs[4] if ds_str == 'LHDs' else 'None'

            ds_df = pd.read_csv(os.path.join(out_dir, ds))
            ds_size = ds_df.shape[0]

            if ds_size == 0:
                # An empty dataset has no threshold and no ratio to compute
                threshold, num_active, percent_a = 'None', 0, 0.0
            else:
                threshold = ds_df['threshold'].unique()[0]
                num_active = ds_df['activity'].sum()
                percent_a = round(num_active / ds_size * 100, 2)

            f.write(f'{task},{target},{effect},{assay},{std_type},{assay_chembl_id},{threshold},{ds_size},{num_active},{ds_size - num_active},{percent_a}\n')

In [7]:
def read_MHDs_generate_LHDs(in_dir:str=CURA_CAT_OR_DIR, task:str = 'cls') -> tuple:
    """
    Read the MHDs in `in_dir/task`, keep those holding between 50 and 5000 data points (inclusive),
    copy the kept MHDs to the size-filtered output folder, and generate per-assay LHDs from them.
    A stats file is then written for the kept MHDs and for the LHDs.

    Args:
        in_dir (str): The path to the input directory containing the MHDs. It must be a key of
            `cat_lhd_dic`. The options are: CURA_CAT_GPCR_DIR, CURA_CAT_OR_DIR
        task (str): 'cls' or 'reg'.

    Returns:
        tuple: (MHDs, MHDs_50_5k, LHDs)
            MHDs: A sorted list of all MHD file names in the input directory.
            MHDs_50_5k: The MHDs that contain between 50 and 5000 data points.
            LHDs: The file names of all LHDs written to disk; each holds the rows of a single
                assay_chembl_id that has between 50 and 5000 data points in its MHD.
    """
    min_size, max_size = 50, 5000
    curated_suffix = '_curated.csv'

    #==================== Read the MHDs  =========================
    # os.listdir returns entries in arbitrary order; sorting keeps the returned lists
    # (and anything seeded downstream that consumes them) reproducible across machines.
    task_dir = os.path.join(in_dir, task)
    MHDs = sorted(f for f in os.listdir(task_dir) if f.startswith('CHEMBL'))
    print(f"Found {len(MHDs)} MHDs in {in_dir}/{task}")

    # Get the unique target_chembl_id in the list of MHDs
    MHDs_tgt = {f.split('_')[0] for f in MHDs}
    print(f"Found {len(MHDs_tgt)} unique target_chembl_id in {in_dir}/{task}")

    mhd_out_dir = os.path.join(cat_lhd_dic[in_dir][0], task)
    lhd_out_dir = os.path.join(cat_lhd_dic[in_dir][1], task)

    # ==================== Generate MHDs_50_5k =========================
    MHDs_50_5k = []
    LHDs = []
    for mhd in MHDs:
        mhd_path = os.path.join(task_dir, mhd)
        # Count rows with the csv reader first so that out-of-range files are never parsed by pandas
        with open(mhd_path, 'r', newline='') as f:
            row_count = sum(1 for _ in csv.reader(f))

        # Subtract 1 for the header row
        n_points = row_count - 1
        if not (min_size <= n_points <= max_size):
            continue
        MHDs_50_5k.append(mhd)

        # ===================== Write MHDs_50_5k CSV files =========================
        # The saved index column is dropped when present; a file without it is used as is
        mhd_df = pd.read_csv(mhd_path).drop(columns=['Unnamed: 0'], errors='ignore')
        os.makedirs(mhd_out_dir, exist_ok=True)
        mhd_df.to_csv(os.path.join(mhd_out_dir, mhd), index=False)

        #==================== Generate LHDs =========================
        try:
            # Get counts and filter valid IDs
            id_counts = mhd_df['assay_chembl_id'].value_counts()

            # the number of data points in a single assay should, on the one hand, be at least 50 to ensure the model can be trained;
            # on the other hand, should not exceed 5000 to avoid high-throughput screens, as these are generally considered noisy
            valid_ids = id_counts[(id_counts >= min_size) & (id_counts <= max_size)].index.tolist()

            if not valid_ids:
                print(f"No valid IDs found for {mhd}. Skipping...")
                continue

            for valid_id in valid_ids:
                lhd_df = mhd_df[mhd_df['assay_chembl_id'] == valid_id]
                # delete the old threshold-derived columns
                lhd_df = lhd_df.drop(columns=['threshold', 'activity_string', 'activity'])
                # apply thresholds again because the new data may have different thresholds
                lhd_df = apply_thresholds(lhd_df)

                # Insert the assay id in front of the '_curated.csv' suffix of the MHD name
                lhd_name = f"{mhd[:-len(curated_suffix)]}_{valid_id}_curated.csv"

                # save to csv
                os.makedirs(lhd_out_dir, exist_ok=True)
                lhd_df.to_csv(os.path.join(lhd_out_dir, lhd_name), index=False)

                # Record the LHD only once its file exists, so the stats step below
                # never tries to read a file that failed to be written
                LHDs.append(lhd_name)

        except Exception as e:
            print(f"Error processing {mhd}: {e}")

    print(f"Found {len(MHDs_50_5k)} MHDs with over 50 data points and smaller than 5000 data points in {in_dir}/{task}")

    # ==================== Write stats_file for MHDs_50_5k ========================
    write_stats_file(in_dir, task, MHDs_50_5k, 'MHDs_50_5k')

    #==================== Write stats_file for LHDs (already between 50-5k) ========================
    write_stats_file(in_dir, task, LHDs, 'LHDs')

    return MHDs, MHDs_50_5k, LHDs

In [8]:
# OR folder, classification task ('cls')
(cat_or_cls_MHDs,
 cat_or_cls_MHDs_50_5k,
 cat_or_cls_LHDs) = read_MHDs_generate_LHDs(in_dir=CURA_CAT_OR_DIR, task='cls')

# OR folder, regression task ('reg')
(cat_or_reg_MHDs,
 cat_or_reg_MHDs_50_5k,
 cat_or_reg_LHDs) = read_MHDs_generate_LHDs(in_dir=CURA_CAT_OR_DIR, task='reg')

Found 38 MHDs in /storage/homefs/yc24j783/datacat4ml/datacat4ml/Data/data_prep/data_curate/cura_cat_ors/cls
Found 4 unique target_chembl_id in /storage/homefs/yc24j783/datacat4ml/datacat4ml/Data/data_prep/data_curate/cura_cat_ors/cls
No valid IDs found for CHEMBL233_agon_G-Ca_EC50_curated.csv. Skipping...
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
No valid IDs found for CHEMBL237_agon_G-cAMP_EC50_curated.csv. Skipping...
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
No valid IDs found for CHEMBL236_agon_G-GTP_EC50_curated.csv. Skipping...
Applying thresholds 
No valid IDs found for CHEMBL2014_agon_G-GTP_EC50_curated.csv. Skipping...
No valid IDs found for CHEMBL233_bind_RBA_IC50_curated.csv. Skipping..

In [9]:
# GPCR folder, classification task ('cls')
(cat_gpcr_cls_MHDs,
 cat_gpcr_cls_MHDs_50_5k,
 cat_gpcr_cls_LHDs) = read_MHDs_generate_LHDs(in_dir=CURA_CAT_GPCR_DIR, task='cls')

# GPCR folder, regression task ('reg')
(cat_gpcr_reg_MHDs,
 cat_gpcr_reg_MHDs_50_5k,
 cat_gpcr_reg_LHDs) = read_MHDs_generate_LHDs(in_dir=CURA_CAT_GPCR_DIR, task='reg')

Found 934 MHDs in /storage/homefs/yc24j783/datacat4ml/datacat4ml/Data/data_prep/data_curate/cura_cat_gpcrs/cls
Found 238 unique target_chembl_id in /storage/homefs/yc24j783/datacat4ml/datacat4ml/Data/data_prep/data_curate/cura_cat_gpcrs/cls
No valid IDs found for CHEMBL210_bind_RBA_Ki_curated.csv. Skipping...
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
No valid IDs found for CHEMBL233_agon_G-Ca_EC50_curated.csv. Skipping...
Applying thresholds 
No valid IDs found for CHEMBL229_bind_RBA_IC50_curated.csv. Skipping...
No valid IDs found for CHEMBL1075162_agon_G-Ca_EC50_curated.csv. Skipping...
No valid IDs found for CHEMBL3974_bind_RBA_IC50_curated.csv. Skipping...
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
No valid IDs found for CHEMBL4761_agon_G-Ca_EC50_curated.csv. Skipping...
Applying thresholds 
Apply

# train-valid-test split (cls and reg)

In [11]:
# Report how many dataset files each list holds, grouped by task.
print('======== cls ==============')
for ds_name, ds_files in (
    ('cat_or_cls_MHDs', cat_or_cls_MHDs),  # => all as the test dataset
    ('cat_or_cls_MHDs_50_5k', cat_or_cls_MHDs_50_5k),
    ('cat_or_cls_LHDs', cat_or_cls_LHDs),
    ('cat_gpcr_cls_MHDs', cat_gpcr_cls_MHDs),  # => all except the or as the pretraining
    ('cat_gpcr_cls_MHDs_50_5k', cat_gpcr_cls_MHDs_50_5k),
    ('cat_gpcr_cls_LHDs', cat_gpcr_cls_LHDs),
):
    print(f'The length of {ds_name} is {len(ds_files)}')

print('========= reg ===========')
for ds_name, ds_files in (
    ('cat_or_reg_MHDs', cat_or_reg_MHDs),  # => all as the test dataset
    ('cat_or_reg_MHDs_50_5k', cat_or_reg_MHDs_50_5k),
    ('cat_or_reg_LHDs', cat_or_reg_LHDs),
    ('cat_gpcr_reg_MHDs', cat_gpcr_reg_MHDs),  # => all except the or as the pretraining
    ('cat_gpcr_reg_MHDs_50_5k', cat_gpcr_reg_MHDs_50_5k),
    ('cat_gpcr_reg_LHDs', cat_gpcr_reg_LHDs),
):
    print(f'The length of {ds_name} is {len(ds_files)}')

The length of cat_or_cls_MHDs is 38
The length of cat_or_cls_MHDs_50_5k is 28
The length of cat_or_cls_LHDs is 26
The length of cat_gpcr_cls_MHDs is 934
The length of cat_gpcr_cls_MHDs_50_5k is 490
The length of cat_gpcr_cls_LHDs is 434
The length of cat_or_reg_MHDs is 38
The length of cat_or_reg_MHDs_50_5k is 28
The length of cat_or_reg_LHDs is 26
The length of cat_gpcr_reg_MHDs is 955
The length of cat_gpcr_reg_MHDs_50_5k is 503
The length of cat_gpcr_reg_LHDs is 453


In [13]:
def split_OR_holdout(gpcr_ds, or_ds, ratio=100, seed=None):
    """
    Build a train/valid/test split in which every dataset of `or_ds` is held out as test data.

    Datasets of `gpcr_ds` whose file name starts with one of the ids in `OR_chemblids`
    (followed by '_') are left out entirely; the remaining ones are shuffled and divided
    into train and valid.

    Parameters:
        gpcr_ds (list): Dataset file names (csv files: MHDs or LHDs) used for train/valid,
            e.g. 'CHEMBL233_bind_RBA_IC50_curated.csv'.
        or_ds (list): Dataset file names that make up the test set.
        ratio (int): Ratio of train to valid (e.g., 100 means 100:1 train:valid split).
            At least one dataset goes to valid whenever any candidate exists.
        seed (int, optional): Random seed for reproducibility. With a seed, the result depends
            only on the set of file names, not on the order in which they are passed.

    Returns:
        defaultdict: Keys 'test', 'train', and 'valid' (always present), each containing a list of datasets.
    """
    # A private generator gives the same shuffle as seeding the module-level one,
    # without resetting the global `random` state for the rest of the session.
    rng = random.Random(seed) if seed is not None else random

    DataFold = defaultdict(list)
    DataFold['test'] = list(or_ds)

    # Sorting first makes the seeded shuffle independent of the caller's list order
    # (directory listings, for example, come back in arbitrary order).
    or_prefixes = tuple(x + '_' for x in OR_chemblids)
    candidates = sorted(assay for assay in gpcr_ds if not assay.startswith(or_prefixes))
    rng.shuffle(candidates)

    # Calculate validation set size
    n_valid = max(1, len(candidates) // (ratio + 1))  # ratio:1 split

    # Split into train/valid
    DataFold['train'] = candidates[n_valid:]
    DataFold['valid'] = candidates[:n_valid]

    return DataFold

In [14]:
# ClsMHDsFold: based on the MHDs because all files except the OR related MHDs should be used during pretraining
ClsMHDsFold = split_OR_holdout(cat_gpcr_cls_MHDs, cat_or_cls_MHDs, ratio=100, seed=42)
with open(os.path.join(SPLIT_DATA_DIR, 'ClsMHDsFold.json'), 'w') as f:
    json.dump(ClsMHDsFold, f, indent=2)

# RegMHDsFold
RegMHDsFold = split_OR_holdout(cat_gpcr_reg_MHDs, cat_or_reg_MHDs, ratio=100, seed=42)
with open(os.path.join(SPLIT_DATA_DIR, 'RegMHDsFold.json'), 'w') as f:
    json.dump(RegMHDsFold, f, indent=2)

In [15]:
# Summarise the classification split: number of folds first, then the size of each fold
print(f'len(ClsMHDsFold): {len(ClsMHDsFold)}')
print('; '.join(
    f'len(ClsMHDsFold["{fold}"]): {len(ClsMHDsFold[fold])}'
    for fold in ('train', 'valid', 'test')
))

len(ClsMHDsFold): 3
len(ClsMHDsFold["train"]): 887; len(ClsMHDsFold["valid"]): 8; len(ClsMHDsFold["test"]): 38


In [16]:
# Summarise the regression split: number of folds first, then the size of each fold
print(f'len(RegMHDsFold): {len(RegMHDsFold)}')
print('; '.join(
    f'len(RegMHDsFold["{fold}"]): {len(RegMHDsFold[fold])}'
    for fold in ('train', 'valid', 'test')
))

len(RegMHDsFold): 3
len(RegMHDsFold["train"]): 907; len(RegMHDsFold["valid"]): 9; len(RegMHDsFold["test"]): 38
