In [17]:
import random
from operator import itemgetter
from collections import Counter,defaultdict, namedtuple
from functools import partial
from itertools import zip_longest, groupby, chain
import mdtraj as md
import numpy as np
from util import *
# Root folders for the Tyk2 competitor setup.
srcPathBase = "/Tyk2/Data"  # folder holding the mol2, frcmod and lib sub-folders
destPathBase = "/Tyk2/6-competitor"  # root containing <bracket>/1.Prepare/Bound_<mol_id> folders

# Criteria for generating ligand heavy-atom-centric contacts:
<ol>
<li> Compute ligand–protein heavy-atom contacts using a distance cut-off of 0.38 nm; </li>
<li> Sort and group all contacts by ligand atom (atom name); </li>
<li> Randomly select at most one contact per ligand heavy atom; </li>
<li> Randomly subsample every ligand in the bracket down to the smallest contact count found among them, and save the result. </li>
</ol>

In [18]:
def formatContact(top, idxPair):
    """Translate a pair of atom indices into printable restraint fields.

    Each index is resolved with ``top.atom``. The result is
    ``(resid_i, name_i, resid_j, name_j)``, where the residue numbers are
    1-based strings (``residue.index + 1``) and the names are atom names.
    """
    first, second = (top.atom(idx) for idx in idxPair)
    return (
        str(first.residue.index + 1), first.name,
        str(second.residue.index + 1), second.name,
    )

def filterContacts(ligand_contact_dict, filter_policy):
    """
    Keep at most ``max_contact_per_field`` randomly chosen contacts per group for each ligand.

    Parameters:
    -----------
    ligand_contact_dict: dict of lists containing contacts for each ligand.
                         key: ligand ID; value: list of contact tuples. In this notebook the
                         tuples are (prot_resid, prot_atom_name, lig_resid, lig_atom_name)
                         as produced by ``formatContact``.
    filter_policy: an object (e.g. a namedtuple instance) with the attributes
                    max_contact_per_field: int >= 0; NOTE: field can be protein residue or ligand atoms.
                    groupby_field: callable; key function mapping a contact to its group.
                                   The returned keys must be orderable (they are sorted).
                    random_seed: int or None; seed used to pick contacts. None uses the
                                 module-level ``random`` generator without seeding.

    Output:
    -------
    filtered_contacts: defaultdict(list) containing filtered contacts for each ligand.
                       Ligands whose contact list is empty are not added as keys.

    Raises:
    -------
    ValueError: if a required attribute is missing from filter_policy or has a wrong type/value.
    """
    required_fields_in_filter_policy = {"random_seed", "groupby_field", "max_contact_per_field"}
    # Check the attributes themselves rather than ``_fields`` so any attribute container works.
    poorly_defined_field = {name for name in required_fields_in_filter_policy
                            if not hasattr(filter_policy, name)}
    if poorly_defined_field:
        raise ValueError(f" {poorly_defined_field} in filter_policy is incorrectly defined.")

    rand_seed = filter_policy.random_seed
    if rand_seed is not None and not isinstance(rand_seed, int):
        raise ValueError("random seed must be an integer.")
    groupby_field = filter_policy.groupby_field
    if not callable(groupby_field):
        raise ValueError("groupby_field must be a callable.")
    max_per_field = filter_policy.max_contact_per_field
    if not isinstance(max_per_field, int):
        raise ValueError("max_contact_per_field must be an integer.")
    if max_per_field < 0:
        raise ValueError("max_contact_per_field must be non-negative.")

    reduced_contacts_dict = defaultdict(list)
    for mol_id, contacts_per_ligand in ligand_contact_dict.items():
        if not contacts_per_ligand:  # don't group empty ones.
            continue
        # itertools.groupby only merges *adjacent* items, so sort by the same key first.
        # The sort is stable: input that is already sorted keeps its original order.
        ordered_contacts = sorted(contacts_per_ligand, key=groupby_field)
        for _field, contacts in groupby(ordered_contacts, key=groupby_field):
            contact_list = list(contacts)
            n_pick = min(len(contact_list), max_per_field)
            # A fresh generator per group reproduces the original "reseed before every
            # sample" picks exactly, without clobbering the global ``random`` state.
            # NOTE(review): this means equally sized groups are sampled at the same
            # positions; kept so previously generated restraint files stay reproducible.
            rng = random if rand_seed is None else random.Random(rand_seed)
            reduced_contacts_dict[mol_id].extend(rng.sample(contact_list, n_pick))
    return reduced_contacts_dict

# Compute bound ligand contacts for a new bracket

In [19]:
# Each bracket is an underscore-joined list of ligand IDs prepared under one shared folder.
brackets = [
    "L01_L02_L03_L14_L15_L16",
]

In [20]:
CONTACT_CUT_OFF = 0.38  # nm
SAVE_CONTACTS = True
folderTemplate = "Bound_{mol_id}"

# One policy for every bracket: keep at most one contact per ligand atom name, picked
# reproducibly. Instantiate the namedtuple rather than assigning attributes onto the
# class (that silently overwrites the class's field descriptors).
FilterPolicy = namedtuple("FilterPolicy", ["max_contact_per_field", "groupby_field", "random_seed"])
filter_policy = FilterPolicy(
    max_contact_per_field=1,
    groupby_field=itemgetter(-1),  # last item of a formatted contact = ligand atom name
    random_seed=1234,
)

for aBracket in brackets:
    # Ligand ID -> ligand residue name used in the atom selection (identical here).
    mol_info = {ligIdx: ligIdx for ligIdx in aBracket.split("_")}
    folder_name = "_".join(sorted(mol_info))
    print(folder_name)

    templateContacts = defaultdict(list)              # mol_id -> [(prot_atom_idx, lig_atom_idx), ...]
    templateContacts_transformed = defaultdict(list)  # mol_id -> formatted contacts sorted by ligand atom name
    for mol_id in mol_info:
        path = folderTemplate.format(mol_id=mol_id)
        with cd(f"{destPathBase}/{folder_name}/1.Prepare/{path}"):
            # Use initial pose to compute contacts
            t = md.load(f"B{mol_id}.inpcrd", top=f"B{mol_id}.prmtop")
            top = t.topology
            # Bind the formatter to THIS ligand's topology: every Bound_<mol_id> system has
            # its own prmtop, so atom names must not be looked up in another ligand's topology.
            convertIdxToName = partial(formatContact, top)

            lig_ha = top.select(f"resname {mol_info[mol_id]} and symbol != 'H'")  # ligand non-hydrogen atoms
            prot_ha = top.select("protein and symbol != 'H'")                      # protein non-hydrogen atoms

            for prot_atom in prot_ha:
                query = np.array([prot_atom])
                neighbor = md.compute_neighbors(t, CONTACT_CUT_OFF, query,
                                                haystack_indices=lig_ha, periodic=True)[0]
                if neighbor.size > 0:
                    templateContacts[mol_id].extend(zip(query.repeat(len(neighbor)), neighbor))

        if mol_id in templateContacts:
            templateContacts_transformed[mol_id] = sorted(
                map(convertIdxToName, templateContacts[mol_id]),
                key=filter_policy.groupby_field,
            )

    reduced_contacts_dict = filterContacts(templateContacts_transformed, filter_policy)

    # Every ligand in the bracket is saved with the same number of restraints: the smallest count.
    contact_counts = {mol_id: len(reduced_contacts_dict.get(mol_id, [])) for mol_id in mol_info}
    no_contacts = [mol_id for mol_id, count in contact_counts.items() if count == 0]
    if no_contacts:
        raise ValueError(f"No ligand-protein contacts within {CONTACT_CUT_OFF} nm for "
                         f"{no_contacts} in bracket {aBracket}.")
    min_amount = min(contact_counts.values())
    print("_".join(mol_info.keys()), min_amount)

    if SAVE_CONTACTS:
        for mol_id in mol_info:
            path = folderTemplate.format(mol_id=mol_id)
            with cd(f"{destPathBase}/{folder_name}/1.Prepare/{path}"):
                # A private generator yields the same picks as random.seed() + random.sample()
                # without touching the global random state. Sample before opening the file
                # so a failure cannot leave a truncated restraint file behind.
                rng = random.Random(filter_policy.random_seed)
                picked_restraint = rng.sample(reduced_contacts_dict[mol_id], min_amount)
                with open(f"B{mol_id}_heavyatom_ligand_centric.txt", "w") as fl:
                    for rest in picked_restraint:
                        fl.write(" ".join(rest) + "\n")







 