In [1]:
import numpy as np
import pandas as pd
import os, glob, sys
from collections import defaultdict
from typing import Union, Optional, List
from tqdm import tqdm
from Bio.PDB import Select
import networkx as nx 
import openmm as mm
import openmm.app as app
import openmm.unit as unit
from scipy.spatial.distance import cdist

os.environ['PATH'] = f'/global/common/software/nersc/pe/conda/24.1.0/Miniconda3-py311_23.11.0-2/condabin/:{os.environ["PATH"]}'
from openff.toolkit import Molecule, Topology
from openmmforcefields.generators import GAFFTemplateGenerator, SMIRNOFFTemplateGenerator
from rdkit import Chem

In [6]:
def to_quantity(ndarray):
    value = [mm.Vec3(float(arr[0]), float(arr[1]), float(arr[2])) for arr in ndarray]
    quantity = unit.Quantity(value, unit=unit.nanometers)
    return quantity

for sdf in tqdm(list(glob.glob('../raw_data_pdbbind_sm_ver2/*/*/*_ligand_fixed.sdf'))):
    mol = Chem.SDMolSupplier(sdf, removeHs=False)[0]
    ligand_missing_atoms = [int(x) for x in mol.GetPropsAsDict().get('Missing Atoms', '')]
    
    # Build ligand force field
    off_mol = Molecule.from_rdkit(mol, allow_undefined_stereo=True, hydrogens_are_explicit=True)
    off_mol.assign_partial_charges('gasteiger')
    ligand_pos = to_quantity(mol.GetConformer().GetPositions() / 10)
    ligand_top = Topology.from_molecules(off_mol).to_openmm()

    ff = app.ForceField('amber14-all.xml')
    ff.registerTemplateGenerator(SMIRNOFFTemplateGenerator(molecules=[off_mol]).generator)

    system = ff.createSystem(ligand_top, nonbondedMethod=app.CutoffNonPeriodic, nonbondedCutoff=0.5 * unit.nanometers, constraints=None)
   

  0%|          | 24/27062 [00:17<5:23:29,  1.39it/s]


KeyboardInterrupt: 

In [4]:
def parse_headers(pdb_file):
    headers = []
    with open(pdb_file) as f:
        for line in f:
            if line.startswith('ATOM') or line.startswith('HETATM'):
                break
            headers.append(line)
    return headers


def parse_missing_residues_and_atoms(headers: List[str]):
    missing_residues = []
    missing_atoms = defaultdict(list)
    read_465, read_470 = False, False
    for line in headers:
        line = line.strip()
        if line.startswith('REMARK 465   M RES C SSSEQI'):
            read_465 = True
            continue
        if line.startswith('REMARK 470   M RES CSSEQI  ATOMS'):
            read_470 = True
            continue

        if read_470:
            if not line.startswith('REMARK 470'):
                read_470 = False
            else:
                content = line.split()
                resname, chain, resid, atoms = content[2], content[3], content[4], content[5:]
                missing_atoms[chain, resid, resname] = atoms

        if read_465:
            if not line.startswith('REMARK 465'):
                read_465 = False
            else:
                content = line.split()
                resname, chain, resid = content[2], content[3], content[4]
                missing_residues.append((chain, resid, resname))
    return missing_residues, missing_atoms


In [5]:
import shutil

def to_quantity(ndarray):
    value = [mm.Vec3(float(arr[0]), float(arr[1]), float(arr[2])) for arr in ndarray]
    quantity = unit.Quantity(value, unit=unit.nanometers)
    return quantity


import openmm.unit as unit
import numpy as np


def modify_bond_force_constants(system, context, positions):
    """
    Modifies the force constant of bonds in an OpenMM system. If the distance 
    between two atoms in the bond is greater than 3 angstroms, the force constant is set to zero.

    Args:
    - system: OpenMM system object
    - positions: OpenMM positions object (list of Vec3 vectors of atomic positions)
    
    Returns:
    - None. The system is modified in place.
    """
    # Get the HarmonicBondForce from the system
    for force_index in range(system.getNumForces()):
        force = system.getForce(force_index)
        if isinstance(force, mm.HarmonicBondForce):
            # Loop over all bonds in the force
            for bond_index in range(force.getNumBonds()):
                # Get the atoms involved in the bond and the bond parameters
                particle1, particle2, length, k = force.getBondParameters(bond_index)

                # Calculate the distance between the two atoms in the bond
                pos1 = np.array(positions[particle1].value_in_unit(unit.nanometer))
                pos2 = np.array(positions[particle2].value_in_unit(unit.nanometer))
                distance = np.linalg.norm(pos1 - pos2) * 10.0  # Convert nanometers to angstroms

                # If the distance is greater than 3 angstroms, set the force constant to zero
                if distance > 3.0:
                    print(particle1, particle2, distance)
                    force.setBondParameters(bond_index, particle1, particle2, length, 0.0)
            
            # After modifying parameters, we need to reinitialize the force
            force.updateParametersInContext(context)



def refinePositionsWithLigand(in_protein, in_ligand, out_protein=None, out_ligand=None, extra_ffs=None, num_opt_cycles=3):
    # if out_protein is None:
    #     shutil.copyfile(in_protein, in_protein + '.backup')
    #     out_protein = in_protein
    #     in_protein = in_protein + '.backup'
    # if out_ligand is None:
    #     shutil.copyfile(in_ligand, in_ligand + '.backup')
    #     out_ligand = in_ligand
    #     in_ligand = in_ligand + '.backup'
    
    headers = parse_headers(in_protein)
    missing_residues, missing_atoms = parse_missing_residues_and_atoms(headers)

    mol = Chem.SDMolSupplier(in_ligand, removeHs=False)[0]
    ligand_missing_atoms = [int(x) for x in mol.GetPropsAsDict().get('Missing Atoms', '')]
    
    # Build ligand force field
    off_mol = Molecule.from_rdkit(mol)
    off_mol.assign_partial_charges('gasteiger')
    ligand_pos = to_quantity(mol.GetConformer().GetPositions() / 10)
    ligand_top = Topology.from_molecules(off_mol).to_openmm()

    extra_ffs = extra_ffs if extra_ffs is not None else []
    ff = app.ForceField('amber14-all.xml', *extra_ffs)
    ff.registerTemplateGenerator(GAFFTemplateGenerator(molecules=[off_mol], forcefield='gaff-2.11').generator)

    pdb = app.PDBFile(in_protein)
    modeller = app.Modeller(pdb.topology, pdb.positions)
    modeller.add(ligand_top, ligand_pos)

    top = modeller.getTopology()
    pos = modeller.getPositions()

    for residue in top.residues():
        if residue.index == top.getNumResidues() - 1:
            print(residue.chain.id, residue.name)
            for atom in residue.atoms():
                print(atom.name)
    return 

    system = ff.createSystem(top, nonbondedMethod=app.CutoffNonPeriodic, nonbondedCutoff=0.5 * unit.nanometers, constraints=None)
   
    constrIndices = []
    for residue in top.residues():
        if residue.index == top.getNumResidues() - 1:
            is_missing_residue = False
            missing_atoms_residue = [atom.name for i, atom in enumerate(residue.atoms()) if i in ligand_missing_atoms]
        else:
            resinfo = (residue.chain.id, f'{residue.id}{residue.insertionCode}'.strip(), residue.name)
            missing_atoms_residue = missing_atoms.get(resinfo, [])
            is_missing_residue = resinfo in missing_residues
        
        for atom in residue.atoms():
            if (atom.element.symbol == 'H') or is_missing_residue or (atom.name in missing_atoms_residue):
                continue
            constrIndices.append(atom.index)

    posConstr = np.array([[vec.x, vec.y, vec.z] for vec in pos])[constrIndices]
    masses = [system.getParticleMass(i) for i in range(system.getNumParticles())]

    temperature = 100
    integrator = mm.LangevinIntegrator(temperature*unit.kelvin, 1/unit.picosecond, 0.0005*unit.picoseconds)
    simulation = app.Simulation(top, system, integrator)
    modify_bond_force_constants(system, simulation.context, pos)

    # Do optimization
    # In one cycle: full opt -> reset posisitons of constr. atoms -> constr. opt
    # This will relax steric clashes
    for _ in range(num_opt_cycles):
        simulation.context.setPositions(pos)
        simulation.minimizeEnergy(tolerance=100, maxIterations=1000)
        simulation.context.setVelocitiesToTemperature(temperature)
        simulation.step(50)
        posTmp = simulation.context.getState(getPositions=True).getPositions(asNumpy=True)._value
        posTmp[constrIndices] = posConstr

        for i in constrIndices:
            system.setParticleMass(i, 0.0*unit.amu)

        simulation.context.setPositions(posTmp)
        simulation.minimizeEnergy(tolerance=100, maxIterations=1000)

        pos = simulation.context.getState(getPositions=True).getPositions()
        for i in range(system.getNumParticles()):
            system.setParticleMass(i, masses[i])
    
    return
    # output 
    fp = open(out_protein, 'w')
    fp.write(''.join(headers))
    app.PDBFile.writeModel(pdb.topology, pos[:pdb.topology.getNumAtoms()], file=fp, keepIds=True)
    app.PDBFile.writeFooter(pdb.topology, file=fp)
    fp.close()

    writer = Chem.SDWriter(out_ligand)
    for i in range(mol.GetNumAtoms()):
        vec = pos[pdb.topology.getNumAtoms() + i]
        mol.GetConformer().SetAtomPosition(i, [vec.x * 10, vec.y * 10, vec.z * 10])
    writer.write(mol)


refinePositionsWithLigand(
    '/pscratch/sd/e/eric6/CLP-PDBBind/edge_cases/5wg6/5wg6_A9G_A_9009/5wg6_A9G_A_9009_protein_fixed.pdb',
    '/pscratch/sd/e/eric6/CLP-PDBBind/edge_cases/5wg6/5wg6_A9G_A_9009/5wg6_A9G_A_9009_ligand_fixed.sdf'
)

X UNK
C1x
C2x
C3x
C4x
C5x
C6x
C7x
C8x
C9x
C10x
C11x
C12x
C13x
C14x
C15x
C16x
C17x
C18x
C19x
C20x
C21x
C22x
C23x
C24x
C25x
C26x
C27x
C28x
C29x
C30x
C31x
N1x
N2x
N3x
N4x
N5x
N6x
O1x
O2x
H1x
H2x
H3x
H4x
H5x
H6x
H7x
H8x
H9x
H10x
H11x
H12x
H13x
H14x
H15x
H16x
H17x
H18x
H19x
H20x
H21x
H22x
H23x
H24x
H25x
H26x
H27x
H28x
H29x
H30x
H31x
H32x
H33x
H34x
H35x
H36x
H37x
H38x
H39x


In [24]:

# forces = { force.__class__.__name__ : force for force in system.getForces() }
# nbforce = forces['NonbondedForce']
# nbparams = [nbforce.getParticleParameters(i) for i in range(nbforce.getNumParticles())]

# torforce = forces['PeriodicTorsionForce']
# torparams = [torforce.getTorsionParameters(i) for i in range(torforce.getNumTorsions())]