In [1]:
# conda env: datacat (Python 3.8.20)
# adapted from FS-Mol: https://github.com/microsoft/FS-Mol/blob/92aa95daba3f43863227e65be85a07b4a2ee754f/notebooks/ExtractDataset.ipynb#L371

# Setup

In [2]:
import os
import sys
import json

import csv
from datacat4ml.const import CURA_CAT_GPCR_DIR, SPLIT_DATA_DIR

# 3. Assay Selection for train-valid-test split

In [94]:
def generate_or_test(input_gpcr_list=LHDs_min32, input_dir=LHDs_dir,
                     or_prefixes=('CHEMBL233_', 'CHEMBL236_', 'CHEMBL237_', 'CHEMBL2014_'),
                     balanced_range=(0.3, 0.7)):
    """Select the balanced OR assays to serve as the few-shot test set.

    Files in ``input_gpcr_list`` whose name starts with one of ``or_prefixes``
    are read from ``input_dir`` and binned by their fraction of active data
    points, computed as ``df['activity'].sum() / number_of_rows``
    (assumes ``activity`` is a 0/1 label -- TODO confirm).

    Parameters
    ----------
    input_gpcr_list : iterable of str
        CSV file names to screen.
    input_dir : str
        Directory that contains those CSV files.
    or_prefixes : iterable of str, optional
        File-name prefixes that identify OR assays.
    balanced_range : tuple of (float, float), optional
        Inclusive ``(low, high)`` bounds on the active fraction for an assay
        to count as balanced.

    Returns
    -------
    list of str
        File names whose active fraction lies within ``balanced_range``,
        in the order they appear in ``input_gpcr_list``.

    Raises
    ------
    ValueError
        If ``balanced_range`` has its lower bound above its upper bound.
    """
    low, high = balanced_range
    if low > high:
        raise ValueError(f'balanced_range must be (low, high) with low <= high, got {balanced_range}')

    # str.startswith accepts a tuple of prefixes but not a list.
    prefixes = tuple(or_prefixes)
    or_list = [f for f in input_gpcr_list if f.startswith(prefixes)]
    print(f'len(or_list): {len(or_list)}')

    test_list_low = []
    test_list_middle = []
    test_list_high = []
    for f in or_list:
        df = pd.read_csv(os.path.join(input_dir, f))
        n_rows = len(df)
        if n_rows == 0:
            # No data points means the active fraction is undefined (division
            # by zero); report the file instead of failing or mis-binning it.
            print(f'skipping {f}: no data points')
            continue
        percentage_active = df['activity'].sum() / n_rows

        # Assays whose active fraction falls inside [low, high] are balanced.
        if percentage_active < low:
            test_list_low.append(f)
        elif percentage_active <= high:
            test_list_middle.append(f)
        else:
            test_list_high.append(f)
    print(f'len(test_list_low): {len(test_list_low)}')
    print(f'len(test_list_middle): {len(test_list_middle)}')
    print(f'len(test_list_high): {len(test_list_high)}')

    return test_list_middle

In [95]:
# Balanced OR assays drawn from the LHD file list.
test_or_lhds = generate_or_test(
    input_gpcr_list=LHDs_min32,
    input_dir=LHDs_dir,
)
# Same selection for the MHD file list, read from the 'cls' subdirectory
# of CURA_CAT_GPCR_DIR.
test_or_mhds = generate_or_test(
    input_gpcr_list=MHDs_min32,
    input_dir=os.path.join(CURA_CAT_GPCR_DIR, 'cls'),
)

len(or_list): 106
len(test_list_low): 0
len(test_list_middle): 21
len(test_list_high): 85
len(or_list): 32
len(test_list_low): 1
len(test_list_middle): 4
len(test_list_high): 27


In [None]:
# Assay-level split for the few-shot experiments: the balanced OR assays form
# the test set and every other LHD assay goes to training.
few_shot_split = {
    'test': test_or_lhds,
    'train': [f for f in LHDs_min32 if f not in test_or_lhds],
    # TODO: the validation assays have not been selected yet. The empty list
    # is a placeholder that keeps this cell syntactically valid; fill it in
    # (and remove those assays from 'train') before using the split.
    'valid': [],
}