In [1]:
%config Completer.use_jedi = False
import numpy as np
import pandas as pd
from rdkit import Chem

from rdkit.Chem.Draw import IPythonConsole

In [2]:
# Root directory of the SDF test set. Keeping it in one constant means the
# notebook can be pointed at another copy of the data by editing a single line.
SDF_ROOT = '/data/hookbill/hadfield/syntheticVS/data/hydrophobic_test/sdf'

lig = Chem.MolFromMolFile(SDF_ROOT + '/ligands/lig1.sdf')
pharm = Chem.MolFromMolFile(SDF_ROOT + '/pharmacophores/pharm1.sdf')

# Fail here, with a readable message, rather than several cells later with an
# AttributeError on None.
assert lig is not None, 'could not parse ligand SDF under ' + SDF_ROOT
assert pharm is not None, 'could not parse pharmacophore SDF under ' + SDF_ROOT

In [None]:
# NOTE(review): this removes every conformer from `lig` in place. Later cells call
# lig.GetConformer(), which presumably cannot succeed afterwards - confirm before running this cell.
lig.RemoveAllConformers()

In [None]:
pharm

In [None]:
lig

In [None]:
def mol_with_atom_index(mol):
    """Set every atom's atom-map number to its own index and return the molecule.

    The molecule is modified in place; the returned object is the same ``mol``
    that was passed in.
    """
    atoms = mol.GetAtoms()
    for current in atoms:
        own_index = current.GetIdx()
        current.SetAtomMapNum(own_index)
    return mol

mol_with_atom_index(lig)

In [24]:
interaction_score, contrib_df, lig_gt, pharm_gt = assign_mol_label(lig, pharm, hydrophobic = True, return_contrib_df=True)

In [28]:
pharm_gt

Unnamed: 0,pharm_atom_idx,x,y,z,contribution
0,0,4.6732,-4.4335,-3.405,0.502441
1,1,0.5436,4.0742,3.1978,2.742854
2,2,9.2665,-0.1294,5.2382,2.254594
3,3,-4.9288,0.5705,-3.2017,0.509926
4,4,-4.9122,1.9939,6.4091,0.339497
5,5,11.0671,-3.8841,-3.1395,0.028887
6,6,-2.0647,-3.8291,3.2919,0.310612


In [29]:
lig_gt

Unnamed: 0,lig_atom_idx,x,y,z,contribution
0,0,-3.7675,1.6316,-1.1564,0.0
1,1,-3.4986,1.0598,0.2383,0.21816
2,2,-1.9948,0.9044,0.5085,0.0
3,3,-1.3602,-0.2576,-0.2709,0.0
4,4,0.0386,-0.4437,0.1313,0.502441
5,5,1.043,0.432,-0.2856,0.0
6,6,0.9565,1.3088,-1.3681,0.0
7,7,1.9542,2.1214,-1.7974,1.804234
8,8,3.1295,2.0351,-1.135,0.0
9,9,3.3688,1.1512,-0.0912,0.043401


In [5]:
# With return_contrib_df=True the assign_mol_label defined in this notebook returns
# (score, contrib_df, lig_gt, pharm_gt). Only the first two are needed here, so the
# remainder is absorbed by the star target instead of failing on a two-name unpack.
score, contrib_df, *_ = assign_mol_label(lig, pharm, hydrophobic = True, return_contrib_df=True)

In [8]:
# Sum the pairwise contributions per ligand atom. The groupby .sum() method is used
# instead of passing Python's builtin `sum` to aggregate(), which pandas has deprecated.
atom_contibutions = contrib_df[['lig_atom_idx', 'contribution']].groupby('lig_atom_idx').sum()

In [17]:
atom_contibutions

Unnamed: 0_level_0,contribution
lig_atom_idx,Unnamed: 1_level_1
1,0.21816
4,0.502441
7,1.804234
9,0.043401
11,1.93225
12,2.081501
18,0.106822


In [18]:
# Print the index label (lig_atom_idx) of every atom that has a contribution entry.
for idx in atom_contibutions.index:
    print(idx)

1
4
7
9
11
12
18


In [12]:
contrib_df[['lig_atom_idx', 'ligand_pos']].drop_duplicates('lig_atom_idx')

Unnamed: 0,lig_atom_idx,ligand_pos
0,4,"[0.0386, -0.4437, 0.1313]"
1,7,"[1.9542, 2.1214, -1.7974]"
2,11,"[5.4036, 2.14, 0.8995]"
3,12,"[6.5869, 1.693, 1.4366]"
10,1,"[-3.4986, 1.0598, 0.2383]"
11,9,"[3.3688, 1.1512, -0.0912]"
12,18,"[-3.6117, -1.4436, -0.3037]"


In [19]:
# Build a per-atom table for the ligand: atom index, coordinates and the summed
# contribution (0 for atoms that have no entry in atom_contibutions).
conformer = lig.GetConformer()  # fetched once rather than once per atom
atom_indices = [atom.GetIdx() for atom in lig.GetAtoms()]
atom_positions = [np.array(conformer.GetAtomPosition(atom_idx)) for atom_idx in atom_indices]

lig_df = pd.DataFrame({'lig_atom_idx': atom_indices,
                       'x': [pos[0] for pos in atom_positions],
                       'y': [pos[1] for pos in atom_positions],
                       'z': [pos[2] for pos in atom_positions]})

# Index-aligned lookup: atom_contibutions is indexed by lig_atom_idx, so map() finds
# each atom's total directly and fillna(0) covers the atoms without a contribution.
# This replaces a nested iterrows() scan over every (atom, contribution) pair.
contribution = lig_df['lig_atom_idx'].map(atom_contibutions['contribution']).fillna(0).tolist()

lig_df['contribution'] = contribution

lig_df

Unnamed: 0,lig_atom_idx,x,y,z,contribution
0,0,-3.7675,1.6316,-1.1564,0.0
1,1,-3.4986,1.0598,0.2383,0.21816
2,2,-1.9948,0.9044,0.5085,0.0
3,3,-1.3602,-0.2576,-0.2709,0.0
4,4,0.0386,-0.4437,0.1313,0.502441
5,5,1.043,0.432,-0.2856,0.0
6,6,0.9565,1.3088,-1.3681,0.0
7,7,1.9542,2.1214,-1.7974,1.804234
8,8,3.1295,2.0351,-1.135,0.0
9,9,3.3688,1.1512,-0.0912,0.043401


In [None]:
def vec_to_vec_dist(p1, p2):
    """Return ``np.linalg.norm`` of the difference between two coordinate arrays.

    For 1-D coordinate vectors this is the Euclidean distance between the points.
    """
    displacement = p1 - p2
    return np.linalg.norm(displacement)




def format_contrib_df(lig, contrib_df, tol=0.05):
    """Add a 'lig_atom_idx' column that maps each row to a ligand atom.

    Every row's 'ligand_pos' is compared with the conformer position of each
    atom in ``lig``; the row is assigned the index of the closest atom, provided
    that atom lies strictly within ``tol``.

    Choosing exactly one atom per row keeps the new column the same length as
    the frame. Collecting every atom within the tolerance (or none) would make
    the assignment fail with an unrelated-looking length error.

    Arguments:
        lig: molecule exposing GetAtoms() and GetConformer().GetAtomPosition(idx)
        contrib_df: DataFrame with a 'ligand_pos' column of 3D coordinates;
            it is modified in place
        tol: maximum distance for a position to count as belonging to an atom

    Returns:
        The same ``contrib_df`` object, now carrying 'lig_atom_idx'.

    Raises:
        ValueError: if some row has no atom within ``tol`` of its position.
    """
    conformer = lig.GetConformer()
    atom_ids = [atom.GetIdx() for atom in lig.GetAtoms()]
    # (n_atoms, 3) coordinate table; reshape keeps the shape valid for zero atoms.
    atom_xyz = np.array(
        [np.array(conformer.GetAtomPosition(atom_id)) for atom_id in atom_ids],
        dtype=float).reshape(-1, 3)

    atom_indices = []
    for row_label, position in contrib_df['ligand_pos'].items():
        dists = np.linalg.norm(atom_xyz - np.asarray(position, dtype=float), axis=1)
        if dists.size == 0 or dists.min() >= tol:
            raise ValueError(
                'no ligand atom within {} of ligand_pos in row {!r}'.format(tol, row_label))
        atom_indices.append(atom_ids[int(dists.argmin())])

    contrib_df['lig_atom_idx'] = atom_indices

    return contrib_df
    

def get_pharm_indices(mol):
    """Collect atom ids per pharmacophore family for ``mol``.

    Features are obtained from the module-level FACTORY. Only the families
    'Hydrophobe', 'Donor', 'Acceptor' and 'LumpedHydrophobe' are kept.

    Returns:
        defaultdict(list) mapping family name -> list of atom ids (ids are
        appended feature by feature). Empty when the molecule has no atoms.
    """
    wanted_families = ['Hydrophobe', 'Donor', 'Acceptor', 'LumpedHydrophobe']
    atoms_by_family = defaultdict(list)

    if mol.GetNumAtoms() >= 1:
        for feature in FACTORY.GetFeaturesForMol(mol):
            family = feature.GetFamily()
            if family in wanted_families:
                atoms_by_family[family].extend(list(feature.GetAtomIds()))

    return atoms_by_family

In [None]:
get_pharm_indices(lig)

In [None]:
# Per-atom totals; .sum() is used instead of aggregate(sum) because passing the
# builtin `sum` to aggregate() is deprecated in pandas.
format_contrib_df(lig, contrib_df)[['lig_atom_idx', 'contribution']].groupby('lig_atom_idx').sum()

In [None]:
# NOTE(review): 'ligand_pos' is filled with numpy arrays (see where contrib_df is built in
# assign_mol_label). Arrays are not hashable, so using them as group keys presumably raises
# a TypeError - confirm, and group on 'lig_atom_idx' (or a tuple copy of the position) instead.
contrib_df[['ligand_pos', 'contribution']].groupby(by = ['ligand_pos']).aggregate(sum)

In [23]:
import argparse
import faulthandler
from collections import defaultdict
from pathlib import Path

import numpy as np
from pathos.multiprocessing import ProcessingPool as Pool
#from point_vs.utils import expand_path, save_yaml
from rdkit import RDConfig, Chem
from rdkit.Chem import ChemicalFeatures
from scipy.stats import gamma


# Shared feature factory built from the 'BaseFeatures.fdef' definition file located in
# RDKit's data directory (RDConfig.RDDataDir). The pharmacophore helpers below read it as a global.
FACTORY = ChemicalFeatures.BuildFeatureFactory(
    str(Path(RDConfig.RDDataDir, 'BaseFeatures.fdef')))

def vec_to_vec_dist(p1, p2):
    """Norm of the difference vector ``p1 - p2`` (Euclidean distance for 1-D points)."""
    offset = p1 - p2
    length = np.linalg.norm(offset)
    return length

def get_atomwise_contributions(mol, contrib_df, ligand_atoms = True, match_tol = 0.05):
    """Attribute pairwise interaction contributions to the individual atoms of ``mol``.

    Each row of ``contrib_df`` stores a position; the row is assigned to the
    atom of ``mol`` whose conformer position is closest to it, provided that
    atom lies strictly within ``match_tol``. The contributions are then summed
    per atom and reported for every atom of the molecule (0 where an atom has
    no matching rows).

    The ligand and pharmacophore cases differ only in the column names used, so
    a single code path handles both.

    Arguments:
        mol: molecule exposing GetAtoms() and GetConformer().GetAtomPosition(idx)
        contrib_df: DataFrame with a 'contribution' column plus a 'ligand_pos'
            column (ligand_atoms=True) or a 'pharm_pos' column
            (ligand_atoms=False); it is modified in place
        ligand_atoms: True to match on 'ligand_pos' and write 'lig_atom_idx',
            False to match on 'pharm_pos' and write 'pharm_atom_idx'
        match_tol: maximum distance for a position to count as belonging to an atom

    Returns:
        (contrib_df, atomwise_df): the same ``contrib_df`` object with the new
        atom-index column, and a new DataFrame with one row per atom and the
        columns <index column>, 'x', 'y', 'z', 'contribution'.

    Raises:
        ValueError: if some row has no atom within ``match_tol`` of its position.
    """
    if ligand_atoms:
        pos_col, idx_col = 'ligand_pos', 'lig_atom_idx'
    else:
        pos_col, idx_col = 'pharm_pos', 'pharm_atom_idx'

    conformer = mol.GetConformer()
    atom_ids = [atom.GetIdx() for atom in mol.GetAtoms()]
    # (n_atoms, 3) coordinate table; reshape keeps the shape valid for zero atoms.
    atom_xyz = np.array(
        [np.array(conformer.GetAtomPosition(atom_id)) for atom_id in atom_ids],
        dtype=float).reshape(-1, 3)

    # Exactly one atom index per row, so the new column always matches the frame
    # length even when two atoms lie within the tolerance of the same position.
    matched = []
    for row_label, position in contrib_df[pos_col].items():
        dists = np.linalg.norm(atom_xyz - np.asarray(position, dtype=float), axis=1)
        if dists.size == 0 or dists.min() >= match_tol:
            raise ValueError(
                'no atom within {} of {} in row {!r}'.format(match_tol, pos_col, row_label))
        matched.append(atom_ids[int(dists.argmin())])
    contrib_df[idx_col] = matched

    # Total contribution per matched atom, indexed by atom index.
    per_atom_total = contrib_df.groupby(idx_col)['contribution'].sum()

    atomwise_df = pd.DataFrame({idx_col: atom_ids,
                                'x': atom_xyz[:, 0],
                                'y': atom_xyz[:, 1],
                                'z': atom_xyz[:, 2]})
    # Index-aligned lookup; atoms without any matched row get a contribution of 0.
    atomwise_df['contribution'] = atomwise_df[idx_col].map(per_atom_total).fillna(0)

    return contrib_df, atomwise_df
    





def assign_mol_label(ligand, pharm_mol, threshold=3.5, fname_idx=None, hydrophobic = False, return_contrib_df = False):
    """Label or score the interactions between a ligand and a pharmacophore mol.

    Pharmacophore atoms are paired with ligand pharmacophore atoms by element:
    'C' with hydrophobes, 'O' with acceptors and 'N' with donors.

    With ``hydrophobic=False`` every matching pair closer than ``threshold``
    is recorded: the ligand position and then the pharmacophore position are
    appended to a flat list. 'C' atoms are paired with both 'Hydrophobe' and
    'LumpedHydrophobe' ligand atoms.

    With ``hydrophobic=True`` the interaction_contribution of every matching
    pair is summed into a score. 'C' atoms are paired with 'Hydrophobe' ligand
    atoms only, and ``threshold`` and ``fname_idx`` are not used.

    Arguments:
        ligand: RDKit mol object (ligand molecule)
        pharm_mol: RDKit mol object (fake receptor pharmacophores)
        threshold: cutoff for interaction distance which is considered an active
            interaction (hydrophobic=False only)
        fname_idx: index of ligand and pharm_mol sdf in directory (if supplied)
        hydrophobic: select the scoring mode instead of the coordinate mode
        return_contrib_df: in scoring mode, also return the pairwise table and
            the per-atom tables from get_atomwise_contributions

    Returns:
        pharm_mol is None: ``[]``, or ``(fname_idx, [])`` when fname_idx is given.
        hydrophobic=False: the coordinate list, or ``(fname_idx, list)``.
        hydrophobic=True: the score, or
            ``(score, contrib_df, lig_gt, pharm_gt)`` with return_contrib_df.
    """

    def get_pharm_indices(mol):
        """Map each pharmacophore family of interest to the atom ids FACTORY reports."""
        wanted_families = ['Hydrophobe', 'Donor', 'Acceptor', 'LumpedHydrophobe']
        atoms_by_family = defaultdict(list)
        if mol.GetNumAtoms() < 1:
            return atoms_by_family

        # NOTE(review): this adds a copy of the molecule's conformer back onto the
        # molecule on every call; the module-level get_pharm_indices in this notebook
        # has the same line commented out. Kept as-is - confirm whether it is needed.
        mol.AddConformer(mol.GetConformer())
        for feature in FACTORY.GetFeaturesForMol(mol):
            family = feature.GetFamily()
            if family in wanted_families:
                atoms_by_family[family] += list(feature.GetAtomIds())

        return atoms_by_family

    if pharm_mol is None:
        return [] if fname_idx is None else (fname_idx, [])

    # Coordinates of the ligand's pharmacophore atoms, grouped by family.
    ligand_pharms_positions = defaultdict(list)
    for family, atom_ids in get_pharm_indices(ligand).items():
        for atom_id in atom_ids:
            ligand_pharms_positions[family].append(
                np.array(ligand.GetConformer().GetAtomPosition(atom_id)))

    if not hydrophobic:
        # Element of the pharmacophore atom -> ligand positions it may pair with.
        partners_by_symbol = {
            'C': (ligand_pharms_positions['Hydrophobe'] +
                  ligand_pharms_positions['LumpedHydrophobe']),
            'O': ligand_pharms_positions['Acceptor'],
            'N': ligand_pharms_positions['Donor'],
        }

        positive_coords = []
        for idx, atom in enumerate(pharm_mol.GetAtoms()):
            atom_pos = np.array(
                pharm_mol.GetConformer().GetAtomPosition(idx))
            for ligand_pharm_position in partners_by_symbol.get(atom.GetSymbol(), ()):
                if vec_to_vec_dist(ligand_pharm_position, atom_pos) < threshold:
                    positive_coords.append(ligand_pharm_position)
                    positive_coords.append(atom_pos)

        if fname_idx is not None:
            return fname_idx, positive_coords
        return positive_coords

    # Scoring mode: lumped hydrophobes are deliberately left out of the 'C' partners.
    partners_by_symbol = {
        'C': ligand_pharms_positions['Hydrophobe'],
        'O': ligand_pharms_positions['Acceptor'],
        'N': ligand_pharms_positions['Donor'],
    }

    interaction_score = 0
    p_positions = []
    l_positions = []
    pairwise_contribution = []

    for idx, atom in enumerate(pharm_mol.GetAtoms()):
        atom_pos = np.array(
            pharm_mol.GetConformer().GetAtomPosition(idx))
        symbol = atom.GetSymbol()
        for ligand_pharm_position in partners_by_symbol.get(symbol, ()):
            # Carbon pairs use the hydrophobic weighting, O/N pairs the polar one.
            ic = interaction_contribution(
                ligand_pharm_position, atom_pos, hydrophobic = (symbol == 'C'))
            interaction_score += ic

            p_positions.append(atom_pos)
            l_positions.append(ligand_pharm_position)
            pairwise_contribution.append(ic)

    if not return_contrib_df:
        return interaction_score

    contrib_df = pd.DataFrame({'pharm_pos':p_positions, 'ligand_pos':l_positions, 'contribution':pairwise_contribution})

    contrib_df, lig_gt = get_atomwise_contributions(ligand, contrib_df, ligand_atoms = True)
    contrib_df, pharm_gt = get_atomwise_contributions(pharm_mol, contrib_df, ligand_atoms = False)

    return interaction_score, contrib_df, lig_gt, pharm_gt

def interaction_contribution(lig_position, prot_position, hydrophobic = False):
    """Score a single ligand/pharmacophore atom pair from their separation.

    The distance between the two positions is passed to gamma_score together
    with the ``hydrophobic`` flag. The caller accumulates these per-pair scores
    into a cumulative interaction score.
    """
    separation = np.linalg.norm(lig_position - prot_position)
    pair_score = gamma_score(separation, hydrophobic = hydrophobic)
    # threshold_score(separation, hydrophobic=hydrophobic) is the step-function alternative.
    return pair_score
    
def gamma_score(x, a = 4, hydrophobic = False):
    """Gamma-pdf shaped score for a distance ``x``.

    Evaluates ``scipy.stats.gamma.pdf`` with shape ``a`` at ``|x|`` and scales
    the result by 3 when ``hydrophobic`` is truthy and by 10 otherwise.
    """
    weight = 3 if hydrophobic else 10
    return weight*gamma.pdf(np.abs(x), a)

def threshold_score(x, hydrophobic = False):
    """Step-function score for a distance ``x``.

    Returns 0 unless ``x < 4``; inside the cutoff the score is 1 when
    ``hydrophobic`` is truthy and 10 otherwise.
    """
    inside_cutoff = x < 4
    if not inside_cutoff:
        return 0
    return 1 if hydrophobic else 10


def label_dataset(root, threshold):
    """Use multiprocessing to post-facto label atoms and mols in sdf dataset.

    Every 'ligands/lig<N>.sdf' under ``root`` is paired with
    'pharmacophores/pharm<N>.sdf', and each pair is passed to assign_mol_label
    in a process pool.

    Arguments:
        root: dataset directory containing the 'ligands' and 'pharmacophores'
            sub-directories
        threshold: distance cutoff forwarded to assign_mol_label for every pair

    Returns:
        (coords_with_positive_label, mol_labels): two dicts keyed by N. The
        first holds the returned coordinates as lists of floats, the second is
        1 when at least one coordinate was returned and 0 otherwise.
    """
    faulthandler.enable()
    # The `expand_path` import is commented out in this notebook, so calling it
    # raises NameError; expand and resolve the path with pathlib instead.
    root = Path(root).expanduser().resolve()
    mol_labels = {}
    coords_with_positive_label = {}
    indices, pharm_mols, lig_mols = [], [], []
    for lig_sdf in Path(root, 'ligands').glob('*.sdf'):
        # 'lig12.sdf' -> 12; the same number selects the matching pharmacophore file.
        idx = int(Path(lig_sdf.name).stem.split('lig')[-1])
        pharm_sdf = str(root / 'pharmacophores' / 'pharm{}.sdf'.format(idx))
        lig_mols.append(Chem.SDMolSupplier(str(lig_sdf))[0])
        pharm_mols.append(Chem.SDMolSupplier(str(pharm_sdf))[0])
        indices.append(idx)
    print('SDFs loaded.')
    thresholds = [threshold] * len(indices)
    # Passing the index as fname_idx makes each result a (idx, positive_coords) pair.
    results = Pool().map(
        assign_mol_label, lig_mols, pharm_mols, thresholds, indices)
    print('SDFs processed.')
    for res in results:
        idx = res[0]
        positive_coords = res[1]
        coords_with_positive_label[idx] = [
            [float(i) for i in coords] for coords in positive_coords]
        mol_labels[idx] = int(len(positive_coords) > 0)
    print('Results constructed.')
    return coords_with_positive_label, mol_labels
