In [1]:
import pandas as pd
import numpy as np
import os
from collections import namedtuple
import random
from tqdm import tqdm
import itertools

In [2]:
# Row masks of one split: training rows, rows with exactly one kept-out
# compound, and rows where both compounds are kept out.
AXVIdx = namedtuple('AXVIndex', ('tridx', 'compound_out', 'both_out'))
def MultipleAXVSplit(df, n_dataset=3, max_trials=1000):
    """Collect up to ``n_dataset`` accepted splits of ``df``.

    Seeds ``0 .. max_trials - 1`` are tried in order. A seed is kept when its
    split passes ``IsValidSplit`` and, once at least two splits are held, the
    collection including it still passes ``IsValidSets``.

    Parameters
    ----------
    df : pandas.DataFrame
        Pair table handed to ``AXV_generator``, ``IsValidSplit`` and
        ``IsValidSets``.
    n_dataset : int, default 3
        Number of splits to collect. Values <= 0 yield an empty dict.
    max_trials : int, default 1000
        Number of consecutive seeds to try before giving up.

    Returns
    -------
    dict
        Maps each accepted seed to its ``AXVIdx``. It may hold fewer than
        ``n_dataset`` entries when the seed budget is exhausted.
    """
    axvsplit = dict()

    for seed in range(max_trials):
        if len(axvsplit) >= n_dataset:
            break

        axv = AXV_generator(df, seed=seed)
        idxs = AXVIdx(axv.get_subset_idx('train'),
                      axv.get_subset_idx('compound_out'),
                      axv.get_subset_idx('both_out')
                      )
        if not IsValidSplit(df, idxs):
            continue

        axvsplit[seed] = idxs

        # Only a freshly added split can invalidate the collection, so the
        # set-level check runs here and the removal always targets a key
        # that is known to exist.
        if len(axvsplit) >= 2 and not IsValidSets(df, axvsplit):
            del axvsplit[seed]

    return axvsplit

class AXV_generator():
    """Label every row of a compound-pair table by how many kept-out compounds it contains.

    A fraction of the unique compounds found in the two ``cols`` columns is
    sampled (seeded) as the "keep-out" set. Each row then gets a code:
    0 = neither compound kept out, 1 = exactly one, 2 = both.

    Parameters
    ----------
    data : pandas.DataFrame
        Table holding the two compound-id columns named in ``cols``.
    keep_out_rate : float, default 0.2
        Fraction of unique compounds to sample (sample size is truncated
        with ``int``).
    seed : int, default 0
        Seed of the sampler.
    cols : sequence of two column names
        Columns holding the first and second compound of each pair.

    Attributes
    ----------
    whole_cpds : list
        Sorted unique compound ids over both columns.
    keepout : list
        Sampled kept-out compound ids.
    identifier : list of int
        One code (0, 1 or 2) per row of ``data``, in row order.
    """

    # Subset name -> identifier code understood by ``get_subset_idx``.
    _SUBSET_CODES = {'train': 0, 'compound_out': 1, 'both_out': 2}

    def __init__(self, data, keep_out_rate=0.2, seed=0, cols=('chembl_cid1', 'chembl_cid2')) -> None:
        self.data = data
        self.rate = keep_out_rate
        self.seed = seed
        # Copy into a fresh list so instances never share (or expose) the default.
        self.cols = list(cols)
        self.whole_cpds = np.union1d(np.unique(data[self.cols[0]]),
                                     np.unique(data[self.cols[1]])).tolist()

        self.keepout = self._get_keepout()
        self.identifier = self._set_identifier()

    def _get_keepout(self):
        """Sample the kept-out compounds with a private, seeded generator.

        ``random.Random(seed).sample`` draws the same items as
        ``random.seed(seed); random.sample`` but leaves the global
        ``random`` state untouched.
        """
        rng = random.Random(self.seed)
        sample_size = int(len(self.whole_cpds) * self.rate)
        return rng.sample(self.whole_cpds, sample_size)

    def _identifier(self, mmp: pd.Series):
        """Return the code (0, 1 or 2) for a single row given as a Series."""
        isin_cpd1 = mmp[self.cols[0]] in self.keepout
        isin_cpd2 = mmp[self.cols[1]] in self.keepout
        return int(isin_cpd1) + int(isin_cpd2)

    def _set_identifier(self):
        """Return the list of codes for all rows of ``self.data``, in row order."""
        # Set membership plus a column-wise zip avoids the per-row Series
        # construction of ``iterrows`` and the linear scan of a list lookup.
        keepout = set(self.keepout)
        first = self.data[self.cols[0]]
        second = self.data[self.cols[1]]
        return [int(a in keepout) + int(b in keepout) for a, b in zip(first, second)]

    def get_subset_idx(self, name):
        """Return a boolean mask (list of bool, one per row) for a subset.

        Parameters
        ----------
        name : str
            ``'train'``, ``'compound_out'`` or ``'both_out'`` (case-insensitive).

        Raises
        ------
        ValueError
            If ``name`` is not one of the known subset names.
        """
        key = name.lower()
        if key not in self._SUBSET_CODES:
            raise ValueError(
                "unknown subset %r; expected one of %s"
                % (name, sorted(self._SUBSET_CODES))
            )
        code = self._SUBSET_CODES[key]
        return [flag == code for flag in self.identifier]
    

def IsValidSplit(data, idxs):
    """Tell whether every subset of a split contains at least one positive row.

    ``idxs`` supplies three row selectors (``tridx``, ``compound_out``,
    ``both_out``) applied to ``data`` with ``.loc``. A row is positive when
    its ``'class'`` value equals 1.

    Returns
    -------
    bool
        True only if the training, compound-out and both-out subsets each
        hold one or more positive rows.
    """
    selectors = (idxs.tridx, idxs.compound_out, idxs.both_out)
    subsets = [data.loc[selector, :] for selector in selectors]

    n_positive = [sub[sub['class'] == 1].shape[0] for sub in subsets]

    return all(count > 0 for count in n_positive)

def IsValidSets(data, dict_idxs):
    """Check that the collected splits differ in their positive test rows.

    For every split in ``dict_idxs`` the index labels of the positive rows
    (``'class' == 1``) in the compound-out and both-out subsets are gathered.
    Splits are then compared pairwise in dict order: for each pair
    ``(earlier, later)`` the earlier split must contain at least one positive
    label that is absent from the later one, for both subset kinds.

    Parameters
    ----------
    data : pandas.DataFrame
        Table with a ``'class'`` column, indexed consistently with the masks.
    dict_idxs : dict
        Maps a seed to an object exposing ``compound_out`` and ``both_out``
        row selectors usable with ``data.loc``.

    Returns
    -------
    bool
        True when every pair passes (trivially True for fewer than two
        splits), False otherwise.
    """
    list_cpdout = []
    list_bothout = []

    for idxs in dict_idxs.values():
        cpdout = data.loc[idxs.compound_out, :]
        bothout = data.loc[idxs.both_out, :]

        list_cpdout.append(cpdout[cpdout['class'] == 1].index)
        list_bothout.append(bothout[bothout['class'] == 1].index)

    # Compare the per-split label sets with each other (not the individual
    # labels of the last split, which can never flag a duplicate split).
    flag_cpdout = all(np.setdiff1d(i, j).shape[0] > 0
                      for i, j in itertools.combinations(list_cpdout, 2))
    flag_bothout = all(np.setdiff1d(i, j).shape[0] > 0
                       for i, j in itertools.combinations(list_bothout, 2))

    return bool(flag_cpdout and flag_bothout)


In [3]:
def axv_single_stats(d_idx, data, trial):
    """Summarise one split as a flat 9-tuple of sizes and proportions.

    The split stored under ``d_idx[trial]`` provides three row selectors
    (``tridx``, ``compound_out``, ``both_out``). For each subset the row
    count and the number of rows with ``'class' == 1`` are reported,
    followed by each subset's share of all rows in ``data``.

    Returns
    -------
    tuple
        ``(n_tr, n_pos_tr, n_cpdout, n_pos_cpdout, n_bothout, n_pos_bothout,
        frac_tr, frac_cpdout, frac_bothout)``.
    """
    split = d_idx[trial]
    selectors = (split.tridx, split.compound_out, split.both_out)
    subsets = [data.loc[selector, :] for selector in selectors]

    counts = []
    for sub in subsets:
        counts.append(sub.shape[0])
        counts.append(sub[sub['class'] == 1].shape[0])

    fractions = [sub.shape[0] / data.shape[0] for sub in subsets]

    return tuple(counts + fractions)

def axv_stats(tname):
    """Compute split statistics for one target.

    Reads ``./Dataset/Data/<tname>.tsv`` (tab-separated, ``id`` column as
    index), generates its splits with ``MultipleAXVSplit`` and summarises
    each of them with ``axv_single_stats``.

    Returns
    -------
    dict
        Maps ``'<tname>-Seed<seed>'`` to a dict of the nine named statistics
        for that split.
    """
    stat_names = ['n_tr', 'n_ac_tr', 'n_cpdout', 'n_ac_cpdout', 'n_bothout', 'n_ac_bothout', 'prop_tr', 'prop_cpdout', 'prop_bothout']
    frame = pd.read_csv('./Dataset/Data/%s.tsv' % tname, sep='\t', index_col='id')

    splits = MultipleAXVSplit(frame)

    return {
        '%s-Seed%d' % (tname, seed): dict(zip(stat_names, axv_single_stats(splits, frame, seed)))
        for seed in splits
    }

In [4]:
# Load the target table (indexed by 'chembl_tid') and keep only the rows
# selected by the 'predictable_trtssplit' column, used here as a row mask.
tlist = pd.read_csv('./Dataset/target_list.tsv', sep='\t', index_col='chembl_tid')
tlist = tlist[tlist['predictable_trtssplit']]

# Gather the per-split statistics of every retained target into one mapping,
# then turn it into a table with one row per '<target>-Seed<seed>' key.
stats_by_split = {}
for t in tqdm(tlist.index):
    stats_by_split.update(axv_stats(t))

stats = pd.DataFrame.from_dict(stats_by_split, orient='index')
    





100%|██████████| 100/100 [00:21<00:00,  4.58it/s]


In [5]:
# Show the per-split statistics table built in the previous cell.
stats

Unnamed: 0,n_tr,n_ac_tr,n_cpdout,n_ac_cpdout,n_bothout,n_ac_bothout,prop_tr,prop_cpdout,prop_bothout
CHEMBL244-Seed0,2081,149,1056,65,146,8,0.633871,0.321657,0.044472
CHEMBL244-Seed1,2215,169,969,50,99,3,0.674688,0.295157,0.030155
CHEMBL244-Seed2,1992,122,1153,87,138,13,0.606762,0.351203,0.042035
CHEMBL204-Seed0,1296,102,762,46,121,9,0.594768,0.349702,0.055530
CHEMBL204-Seed1,1363,121,718,35,98,1,0.625516,0.329509,0.044975
...,...,...,...,...,...,...,...,...,...
CHEMBL203-Seed12,43,1,24,2,2,1,0.623188,0.347826,0.028986
CHEMBL203-Seed31,48,1,17,1,4,2,0.695652,0.246377,0.057971
CHEMBL262-Seed6,41,2,22,1,5,2,0.602941,0.323529,0.073529
CHEMBL262-Seed11,43,3,22,1,3,1,0.632353,0.323529,0.044118


In [6]:
# Write the statistics table to a tab-separated file.
stats.to_csv('./Dataset/Stats/axv.tsv', sep='\t')