In [1]:
import os
import json
import datetime
from pathlib import Path
from rdkit import Chem
from rdkit.Chem import Descriptors, rdMolDescriptors, Lipinski
from rdkit.Chem import DataStructs
from Bio.PDB import PDBParser

import matplotlib.pyplot as plt
import datetime
from tqdm.notebook import tqdm

In [2]:
PWD = '/lambda_stor/homes/songhao.jiang/projects/boltz/scripts/process'

In [3]:
MIN_SEQ_ID = 0.4
DATE_START = datetime.date(2023, 6, 1)
DATE_END = datetime.date(2024, 1, 1)
MAX_RESOLUTION = 4.5
MAX_RESIDUES = 1024
MAX_ENTITIES = 20
VALIDATION_OUTPUT = "filtered_validation_set.json"

with open('/scratch/shjiang/DATA/boltz_1/rcsb_processed_targets/manifest.json', 'r') as f:
    manifest = json.load(f)

In [4]:
def get_chain_clusters(cluster_mapping_file):
    # output: dict{chain_id -> cluster_id}
    cluster_map = {}
    with open(cluster_mapping_file) as f:
        for line in f:
            cluster_id, *chains = line.strip().split()
            for chain in chains:
                cluster_map[chain] = cluster_id
    return cluster_map

In [5]:
cluster_map = get_chain_clusters(f"{PWD}/new_clustering/cluster_prot_cluster.tsv")  # chain_id → cluster_id

In [6]:
training = []
exceed_length = []
not_training = []
for record in tqdm(manifest):
    resolution = record['structure']['resolution']
    release_date = datetime.datetime.strptime(record['structure']['released'], '%Y-%m-%d').date()
    
    if release_date < DATE_START:
        chains = record['chains']
        num_residues = sum([i['num_residues'] for i in record['chains']])
        if num_residues < 5000:
            training.append(record)
        else:
            exceed_length.append(record)
            not_training.append(record)
    else:
        not_training.append(record)

  0%|          | 0/216870 [00:00<?, ?it/s]

In [7]:
training_pdb_ids = set([i['id'] for i in training])
print(len(training),len(training_pdb_ids),len(manifest),len(exceed_length))

197996 197996 216870 6610


In [8]:
# filter_0: date
validation_all = []
# filter_1: resolution
validation_filter_1 = []
# filter_2: sequence identity 
validation_filter_2 = []
for record in tqdm(not_training):
    resolution = record['structure']['resolution']
    release_date = datetime.datetime.strptime(record['structure']['released'], '%Y-%m-%d').date()
    
    if release_date > DATE_START and release_date <= DATE_END:
        validation_all.append(record)
        if resolution < 4.5:
            validation_filter_1.append(record)
            pdb_id = record['id']
            chains = record['chains']
            if any(cluster_map.get(f"{pdb_id}_{chain['chain_name'][0]}", '').split('_')[0] in training_pdb_ids for chain in chains):
                continue
            else:
                validation_filter_2.append(record)
print(len(validation_all), len(validation_filter_1),len(validation_filter_2))

  0%|          | 0/18874 [00:00<?, ?it/s]

8671 8671 3039


In [9]:
# filter_3: currently there exists some pdbs which cannot be downloaded from rtcb
validation_filter_3 = []
for i in validation_filter_2:
    pdb_id = i['id']
    if os.path.exists(f'/scratch/shjiang/DATA/boltz_2/rcsb_processed_targets/{pdb_id}.pkl'):
        validation_filter_3.append(i)
len(validation_filter_3)

3026

In [10]:
import pickle 
from Bio.PDB.MMCIF2Dict import MMCIF2Dict
from boltz.data.const import ligand_exclusion
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
import itertools
import pandas as pd

from rdkit import RDLogger                                                                                                                                                               
RDLogger.DisableLog('rdApp.*') 


def satisfies_lipinski_rule(mol) -> bool:
#     mol = Chem.MolFromSmiles(ligand['smiles'])
    if mol is None:
        return False
    try:
        mw = Descriptors.MolWt(mol)
        logp = Descriptors.MolLogP(mol)
        h_donors = Lipinski.NumHDonors(mol)
        h_acceptors = Lipinski.NumHAcceptors(mol)
        return (
            mw <= 500 and
            logp <= 5 and
            h_donors <= 5 and
            h_acceptors <= 10
        )
    except:
        return False
with open('/scratch/shjiang/DATA/boltz_1/ccd/ccd.pkl', 'rb') as file:
    ccd = pickle.load(file)
valid_ccd = {k:v for k,v in ccd.items() if k not in ligand_exclusion}

In [None]:
# filter 4: check 1) no mol, 2) at least one mol satisfies Lipinski’s Rule of Five
validation_filter_4 = []
mols_for_sim = {}
for i in tqdm(validation_filter_3):
    pdb_id = i['id']
    cif_file = f'/scratch/shjiang/DATA/boltz_1/PDB/{pdb_id}.cif'
    cif = MMCIF2Dict(cif_file)
    try:
        non_poly = cif['_pdbx_entity_nonpoly.comp_id']
#         if not non_poly:
#             validation_filter_4.append(i)
#         else:
        mols = {i:ccd.get(i, None) for i in non_poly if i not in ligand_exclusion}
        mols = {k:v for k,v in mols.items() if v and v.GetNumHeavyAtoms() > 1}
        if not mols:
            # all mols has heavy atoms < 1, no need for similarity search
            validation_filter_4.append(i)
            continue
        
        lipinski = [satisfies_lipinski_rule(mol) for k,mol in mols.items()]
        if any(lipinski):
            validation_filter_4.append(i)
            continue

        # store for futher usage
        mols_for_sim[pdb_id] = non_poly
    except KeyError:
        # does not contain non-poly -> does not contain molecule
        validation_filter_4.append(i)
        
# with open('/lambda_stor/homes/songhao.jiang/projects/boltz/DATA/training_mols/validation_filter_4.pkl', 'wb') as file:
#     pickle.dump(validation_filter_4, file)

In [11]:
with open("/lambda_stor/homes/songhao.jiang/projects/boltz/DATA/training_mols/validation_filter_4.pkl", "rb") as f:
    validation_filter_4 = pickle.load(f)
    
with open("/lambda_stor/homes/songhao.jiang/projects/boltz/DATA/training_mols/mols_for_sim.pkl", "rb") as f:
    mols_for_sim = pickle.load(f)

In [12]:
len(validation_filter_4)

2319

In [13]:
mols_only_for_sim = set().union(*mols_for_sim.values())

In [14]:
from rdkit.Chem import rdFingerprintGenerator
from collections import defaultdict
def pairwise_tanimoto_similarity(target, training, mol_dict, fpSize=1024):
    # Generate Morgan fingerprints for each molecule
    mfpgen = rdFingerprintGenerator.GetMorganGenerator(radius=2,fpSize=fpSize)
    fps = {}
    for k, mol in mol_dict.items():
        try:
            f_ = mfpgen.GetFingerprint(mol) 
            fps[k] = f_
        except:
            pass
            
    print(f"len(fps): {len(fps)}")
    
    # Compute pairwise similarities
    results = defaultdict(list)
    print("starting....")
    for k1 in tqdm(target):
        for k2 in training:
            try:
                sim = DataStructs.TanimotoSimilarity(fps[k1], fps[k2])
            except:
                sim = 0
            if k1 not in results:
                results[k1] = []
            results[k1].append((k2, sim))
    
    # Convert to DataFrame
    return results

In [15]:
# get mols in training data
# train_mols_for_sim = {}
# for i in tqdm(training):
#     pdb_id = i['id']
#     cif_file = f'/scratch/shjiang/DATA/boltz_1/PDB/{pdb_id}.cif'
#     cif = MMCIF2Dict(cif_file)
#     try:
#         non_poly = cif['_pdbx_entity_nonpoly.comp_id']
#         train_mols_for_sim[pdb_id] = non_poly
#     except:
#         pass
# with open('/lambda_stor/homes/songhao.jiang/projects/boltz/DATA/training_mols/train_mols_for_sim.pkl', 'wb') as file:
#     pickle.dump(train_mols_for_sim, file)

with open("/lambda_stor/homes/songhao.jiang/projects/boltz/DATA/training_mols/train_mols_for_sim.pkl", "rb") as f:
    obj = pickle.load(f)
    
training_mols = []
for k,v in obj.items():
    training_mols.extend(v)
training_mols = list(set(training_mols))
len(training_mols)

36381

In [None]:
# get tonimoto similarity between mols in validation and mols in training

# results = pairwise_tanimoto_similarity(mols_only_for_sim, training_mols, valid_ccd, 2048)

# with open('/lambda_stor/homes/songhao.jiang/projects/boltz/DATA/training_mols/sim_results.pkl', 'wb') as file:
#     pickle.dump(results, file)

In [16]:
with open("/lambda_stor/homes/songhao.jiang/projects/boltz/DATA/training_mols/sim_results.pkl", "rb") as f:
    results = pickle.load(f)

In [17]:
validation_filter_4_sim = []
for pdb_id, list_mols in tqdm(mols_for_sim.items()):
    for mol in list_mols:
        sims = results[mol]
        sims = [i[1] < 0.85 for i in sims]
        if any(sims):
            validation_filter_4_sim.append(pdb_id)
            break

  0%|          | 0/707 [00:00<?, ?it/s]

In [18]:
len(validation_filter_4_sim), len(mols_for_sim), len(validation_filter_4)

(707, 707, 2319)

In [19]:
len(validation_filter_4_sim) + len(validation_filter_4), len(validation_filter_3)

(3026, 3026)

In [20]:
validation_before_further = []
for i in validation_filter_4:
    validation_before_further.append(i['id'])
for i in validation_filter_4_sim:
    validation_before_further.append(i)

In [21]:
len(validation_before_further)

3026

In [22]:
validation_filter_5 = []
for record in tqdm(manifest):
    if record['id'] in validation_before_further:
        chains = record['chains']
        
        num_entites = len(chains)
        num_residues = 0
        chain_type = []
        for chain in chains:
            num_residues += chain['num_residues']
            chain_type.append(chain['mol_type'])
        if 1 in chain_type or 2 in chain_type:
            validation_filter_5.append(record['id'])
        else:
            if num_residues > 1024:
                continue
            if num_entites > 20:
                continue
        validation_filter_5.append(record['id'])

  0%|          | 0/216870 [00:00<?, ?it/s]

In [23]:
len(validation_filter_5)

2344