In [1]:
# conda env: datacat (Python 3.8.20)
import os
import csv
import random
import json
from collections import defaultdict
import shutil

import pandas as pd
from datacat4ml.const import CURA_CAT_GPCR_DIR, CURA_LHD_GPCR_DIR, CURA_CAT_OR_DIR, CURA_LHD_OR_DIR, OR_chemblids, SPLIT_DATA_DIR
from datacat4ml.Scripts.data_prep.data_curate.utils.apply_thresholds import apply_thresholds

# Read MHDs from `CURA_CAT_GPCR_DIR` and `CURA_CAT_OR_DIR`

In [12]:
# Lookup table: MHD input directory -> base directory that receives the LHDs derived from it.
cat_lhd_dic = {}
cat_lhd_dic[CURA_CAT_GPCR_DIR] = CURA_LHD_GPCR_DIR
cat_lhd_dic[CURA_CAT_OR_DIR] = CURA_LHD_OR_DIR

def _count_data_rows(csv_path: str) -> int:
    """Return the number of rows in a CSV file, not counting the header row."""
    with open(csv_path, 'r', newline='') as f:
        n_rows = sum(1 for _ in csv.reader(f))
    # Subtract 1 for the header row; an empty file has 0 data rows.
    return max(n_rows - 1, 0)


def read_MHDs_generate_LHDs(in_dir:str=CURA_CAT_OR_DIR, task:str = 'cls', min_points:int = 32, max_points:int = 5000) -> tuple:
    """
    Read the MHD csv files of one task, and split every sufficiently large MHD into per-assay LHD csv files.

    WARNING: the output directory `cat_lhd_dic[in_dir]/task` is deleted and recreated on every call.

    Args:
        in_dir (str): The path to the input directory containing the MHDs. It must be a key of `cat_lhd_dic`;
            the options are: CURA_CAT_GPCR_DIR, CURA_CAT_OR_DIR
        task (str): Name of the sub-directory of `in_dir` (and of the output directory) to process, e.g. 'cls' or 'reg'.
        min_points (int): Minimum number of data points an MHD file, a single assay, and a written LHD file must have.
        max_points (int): Maximum number of data points a single assay may have to be written out as an LHD.
    Returns:
        MHDs: The sorted file names of ALL MHDs found in `in_dir/task` (not only those with at least `min_points` rows).
        LHDs: The sorted file names of the generated LHDs that contain at least `min_points` data points.
    """
    #==================== Read the MHDs =========================
    task_dir = os.path.join(in_dir, task)

    # os.listdir returns entries in arbitrary order; sort so that the returned lists
    # (and any seeded shuffle applied to them later) are reproducible across machines.
    MHDs = sorted(f for f in os.listdir(task_dir) if f.startswith('CHEMBL'))
    print(f"Found {len(MHDs)} MHDs in {in_dir}/{task}")

    # Get the unique target_chembl_id (the part of the file name before the first '_')
    MHDs_tgt = {f.split('_')[0] for f in MHDs}
    print(f"Found {len(MHDs_tgt)} unique target_chembl_id in {in_dir}/{task}")

    # Keep only the MHDs that contain at least `min_points` data points
    MHDs_kept = [mhd for mhd in MHDs if _count_data_rows(os.path.join(task_dir, mhd)) >= min_points]
    print(f"Found {len(MHDs_kept)} MHDs with at least {min_points} data points in {in_dir}/{task}")

    #==================== Generate LHDs =========================
    LHDs_dict = defaultdict(list)
    out_dir = os.path.join(cat_lhd_dic[in_dir], task)
    # Start from an empty output directory so LHDs from a previous run cannot linger.
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

    curated_suffix = '_curated'
    for assay in MHDs_kept:
        assay_path = os.path.join(task_dir, assay)

        # Best effort: a file that cannot be processed is reported and skipped, not fatal.
        try:
            # Read the CSV file; a missing 'Unnamed: 0' column raises KeyError and the file is skipped
            # (presumably the column is a saved pandas index -- TODO confirm)
            assay_df = pd.read_csv(assay_path).drop(columns=['Unnamed: 0'])

            # Get counts and filter valid IDs
            id_counts = assay_df['assay_chembl_id'].value_counts()

            # the number of data points in a single assay should, on the one hand, be at least `min_points` to ensure the model can be trained;
            # on the other hand, should not exceed `max_points` to avoid high-throughput screens, as these are generally considered noisy
            valid_ids = id_counts[(id_counts >= min_points) & (id_counts <= max_points)].index.tolist()

            if not valid_ids:
                print(f"No valid IDs found for {assay}. Skipping...")
                continue

            # File-name stem without the trailing '_curated' marker (left untouched if the marker is absent)
            basename = os.path.splitext(assay)[0]
            stem = basename[:-len(curated_suffix)] if basename.endswith(curated_suffix) else basename

            # Save the rows of every valid assay to its own CSV file
            for assay_chembl_id in valid_ids:
                df = assay_df[assay_df['assay_chembl_id'] == assay_chembl_id]
                # delete the old threshold-derived columns
                df = df.drop(columns=['threshold', 'activity_string', 'activity'])
                # apply thresholds again because the new data may have different thresholds
                df = apply_thresholds(df)

                # save to csv
                output_fname = f"{stem}_{assay_chembl_id}_curated.csv"
                df.to_csv(os.path.join(out_dir, output_fname), index=False)

                # Remember which assays were extracted from which MHD
                LHDs_dict[basename].append(assay_chembl_id)

        except Exception as e:
            print(f"Error processing {assay}: {e}")

    # `out_dir` already ends with `task`, so it is printed as-is
    print(f"{len(LHDs_dict)} MHDs csv files contain qualified LHDs in {out_dir}")

    # Re-check the written files: keep the LHDs that contain at least `min_points` data rows
    LHDs = sorted(f for f in os.listdir(out_dir) if f.startswith('CHEMBL'))
    LHDs_kept = [lhd for lhd in LHDs if _count_data_rows(os.path.join(out_dir, lhd)) >= min_points]
    print(f"Found {len(LHDs_kept)} LHDs with at least {min_points} data points in {out_dir}")

    return MHDs, LHDs_kept

In [13]:
# OR data: list the MHDs and (re)generate the LHDs, once per task.
cat_or_cls_MHDs, cat_or_cls_LHDs = read_MHDs_generate_LHDs(in_dir=CURA_CAT_OR_DIR, task='cls')
cat_or_reg_MHDs, cat_or_reg_LHDs = read_MHDs_generate_LHDs(in_dir=CURA_CAT_OR_DIR, task='reg')

Found 38 MHDs in /storage/homefs/yc24j783/datacat4ml/datacat4ml/Data/data_prep/data_curate/cura_cat_ors/cls
Found 4 unique target_chembl_id in /storage/homefs/yc24j783/datacat4ml/datacat4ml/Data/data_prep/data_curate/cura_cat_ors/cls
Found 32 MHDs with over 32 data points in /storage/homefs/yc24j783/datacat4ml/datacat4ml/Data/data_prep/data_curate/cura_cat_ors/cls
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
No valid IDs found for CHEMBL236_agon_G_Ca_EC50_curated.csv. Skipping...
Applying thres

In [14]:
# GPCR data: list the MHDs and (re)generate the LHDs, once per task.
cat_gpcr_cls_MHDs, cat_gpcr_cls_LHDs = read_MHDs_generate_LHDs(in_dir=CURA_CAT_GPCR_DIR, task='cls')
cat_gpcr_reg_MHDs, cat_gpcr_reg_LHDs = read_MHDs_generate_LHDs(in_dir=CURA_CAT_GPCR_DIR, task='reg')

Found 935 MHDs in /storage/homefs/yc24j783/datacat4ml/datacat4ml/Data/data_prep/data_curate/cura_cat_gpcrs/cls
Found 238 unique target_chembl_id in /storage/homefs/yc24j783/datacat4ml/datacat4ml/Data/data_prep/data_curate/cura_cat_gpcrs/cls
Found 573 MHDs with over 32 data points in /storage/homefs/yc24j783/datacat4ml/datacat4ml/Data/data_prep/data_curate/cura_cat_gpcrs/cls
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
No valid IDs found for CHEMBL229_bind_RBA_IC50_curated.csv. Skipping...
Applying thresholds 
Applying thresholds 
Applying thresholds 
Applying thresholds 
No valid IDs found for CHEMBL1628461_antag_G_Ca_IC50_curated.csv. Skipping...
Applying thresholds 
Applying thresholds 
Applying th

# Train-valid-test split (cls and reg)

In [15]:
# Report how many MHD / LHD files are available for each task.
# OR MHDs   => all used as the test dataset
# GPCR MHDs => all except the OR ones used for pretraining
print('======== cls ==============')
for label, files in (
    ('cat_or_cls_MHDs', cat_or_cls_MHDs),
    ('cat_gpcr_cls_MHDs', cat_gpcr_cls_MHDs),
    ('cat_or_cls_LHDs', cat_or_cls_LHDs),
    ('cat_gpcr_cls_LHDs', cat_gpcr_cls_LHDs),
):
    print(f'The length of {label} is {len(files)}')

print('========= reg ===========')
for label, files in (
    ('cat_or_reg_MHDs', cat_or_reg_MHDs),
    ('cat_gpcr_reg_MHDs', cat_gpcr_reg_MHDs),
    ('cat_or_reg_LHDs', cat_or_reg_LHDs),
    ('cat_gpcr_reg_LHDs', cat_gpcr_reg_LHDs),
):
    print(f'The length of {label} is {len(files)}')

The length of cat_or_cls_MHDs is 38
The length of cat_gpcr_cls_MHDs is 935
The length of cat_or_cls_LHDs is 109
The length of cat_gpcr_cls_LHDs is 1516
The length of cat_or_reg_MHDs is 38
The length of cat_gpcr_reg_MHDs is 955
The length of cat_or_reg_LHDs is 109
The length of cat_gpcr_reg_LHDs is 1562


In [16]:
def split_OR_holdout(cat_gpcr_cls_MHDs, cat_or_cls_MHDs, ratio=100, seed=None):
    """
    Build an OR hold-out split: the OR files form the test set, the remaining GPCR files are split into train and valid.

    Parameters:
        cat_gpcr_cls_MHDs (list): File names of the GPCR datasets (csv files: MHDs or LHDs),
            e.g. 'CHEMBL233_bind_RBA_IC50_curated.csv'. Files whose name starts with '<id>_' for any id in
            `OR_chemblids` are left out; the others are shuffled and divided into train and valid.
        cat_or_cls_MHDs (list): File names of the OR datasets; all of them go to the test set, in the given order.
        ratio (int): Ratio of train to valid (e.g., 100 means 100:1 train:valid split). Must be non-negative.
            At least one file is assigned to valid whenever any non-OR file is available.
        seed (int, optional): Random seed for reproducibility. When given, a private random generator is used, so
            the module-level `random` state is left untouched; when None, the module-level `random` state is used.

    Returns:
        defaultdict(list): Always contains the keys 'test', 'train' and 'valid', each mapping to a list of file names.

    Raises:
        ValueError: If `ratio` is negative.
    """
    if ratio < 0:
        raise ValueError(f"ratio must be non-negative, got {ratio}")

    # random.Random(seed) yields the same sequence as random.seed(seed) followed by the module-level
    # functions, but without re-seeding the global generator as a side effect.
    rng = random.Random(seed) if seed is not None else random

    # str.startswith accepts a tuple of prefixes; the trailing '_' prevents e.g. CHEMBL233 matching CHEMBL2330
    or_prefixes = tuple(x + '_' for x in OR_chemblids)

    DataFold = defaultdict(list)

    # Initial train/test: every OR file is test, every non-OR GPCR file is a train candidate
    DataFold['test'] = list(cat_or_cls_MHDs)
    train_pool = [assay for assay in cat_gpcr_cls_MHDs if not assay.startswith(or_prefixes)]

    # Shuffle for randomness
    rng.shuffle(train_pool)

    # Calculate validation set size
    n_total = len(train_pool)
    n_valid = max(1, n_total // (ratio + 1))  # ratio:1 split

    # Split into train/valid (keys are inserted in the order test, train, valid)
    DataFold['train'] = train_pool[n_valid:]
    DataFold['valid'] = train_pool[:n_valid]

    return DataFold

In [17]:
# ClsMHDsFold: based on the MHDs because all files except the OR related MHDs should be used during pretraining
# Make sure the output directory exists; open(..., 'w') does not create missing directories.
os.makedirs(SPLIT_DATA_DIR, exist_ok=True)

# ClsMHDsFold: based on the MHDs because all files except the OR related MHDs should be used during pretraining
ClsMHDsFold = split_OR_holdout(cat_gpcr_cls_MHDs, cat_or_cls_MHDs, ratio=100, seed=42)
with open(os.path.join(SPLIT_DATA_DIR, 'ClsMHDsFold.json'), 'w') as f:
    json.dump(ClsMHDsFold, f, indent=2)

# RegMHDsFold: same split procedure applied to the regression MHDs
RegMHDsFold = split_OR_holdout(cat_gpcr_reg_MHDs, cat_or_reg_MHDs, ratio=100, seed=42)
with open(os.path.join(SPLIT_DATA_DIR, 'RegMHDsFold.json'), 'w') as f:
    json.dump(RegMHDsFold, f, indent=2)

In [18]:
# Report the number of splits and the number of files in each split of ClsMHDsFold.
# (The former bare `ClsMHDsFold` expression displayed nothing because it was not the last statement of the cell.)
print(f'len(ClsMHDsFold): {len(ClsMHDsFold)}')
print('; '.join(f'len(ClsMHDsFold["{split}"]): {len(ClsMHDsFold[split])}' for split in ('train', 'valid', 'test')))

len(ClsMHDsFold): 3
len(ClsMHDsFold["train"]): 888; len(ClsMHDsFold["valid"]): 8; len(ClsMHDsFold["test"]): 38


In [19]:
# Report the number of splits and the number of files in each split of RegMHDsFold.
# (The former bare `RegMHDsFold` expression displayed nothing because it was not the last statement of the cell.)
print(f'len(RegMHDsFold): {len(RegMHDsFold)}')
print('; '.join(f'len(RegMHDsFold["{split}"]): {len(RegMHDsFold[split])}' for split in ('train', 'valid', 'test')))

len(RegMHDsFold): 3
len(RegMHDsFold["train"]): 907; len(RegMHDsFold["valid"]): 9; len(RegMHDsFold["test"]): 38
