# Qualaria (Ctd)

This code is the second part of the Qualaria ligand vetting flow. It computes the ground state energies of the ligands that the first notebook generated and selected for each pocket.

We ran this code with Python `3.11.13` on Google Colab and Kaggle; use that version as a reference. It may also run on older Python versions, but we did not test them.

## 3. Evaluating the Ground State Energies

In [23]:
!pip uninstall -y qiskit qiskit-nature qiskit-nature-pyscf pylatexenc --quiet
!pip install "qiskit==1.4.3" qiskit-aer qiskit-nature qiskit-nature-pyscf pylatexenc --quiet

In [24]:
import os
import time
import json

from pyscf import gto, scf, qmmm
from pyscf.solvent import ddcosmo

from qiskit_nature.units import DistanceUnit
from qiskit_nature.second_q.drivers import PySCFDriver
from qiskit_nature.second_q.problems import ElectronicBasis
from qiskit_nature.second_q.transformers import BasisTransformer
from qiskit_nature.second_q.transformers import ActiveSpaceTransformer
from qiskit_nature.second_q.mappers import JordanWignerMapper
from qiskit_nature.second_q.circuit.library import HartreeFock, UCCSD
from qiskit_nature.second_q.algorithms.initial_points import HFInitialPoint
from qiskit_algorithms import VQE
# from qiskit_algorithms.optimizers import SLSQP
from qiskit_algorithms.optimizers import COBYLA
from qiskit.primitives import Estimator, BackendEstimator
from qiskit_aer import StatevectorSimulator, QasmSimulator
from qiskit_nature.second_q.algorithms import GroundStateEigensolver
from qiskit_algorithms import NumPyMinimumEigensolver
from qiskit_algorithms.utils import algorithm_globals
from qiskit_nature.second_q.operators import ElectronicIntegrals
from qiskit_nature.second_q.hamiltonians import ElectronicEnergy
from qiskit_nature.second_q.problems import ElectronicStructureProblem
from qiskit_nature.second_q.formats.molecule_info import MoleculeInfo
from qiskit_nature.second_q.circuit.library import HartreeFock, UCCSD

In [25]:
# Fix the global random seed of qiskit_algorithms (algorithm_globals) so that
# repeated runs of this notebook are intended to give consistent results.
algorithm_globals.random_seed = 42

In [26]:
class GroundStateEnergyCalculator:
    """
    Ground State Energy Calculator using the Variational Quantum Eigensolver (VQE).

    It supports the calculation of the Ground State Energy (GSE) both in vacuum
    and in a solvent (the GSE in a solvent will be used for the calculation of
    the Binding Energy). Both paths reduce the electronic structure problem to
    an active space, map it with the Jordan-Wigner mapper and optimize a UCCSD
    ansatz that starts from the Hartree-Fock state.

    NOTE(review): ``result.groundenergy`` appears to be the eigenvalue of the
    active-space qubit Hamiltonian only; constant offsets (inactive-core energy,
    nuclear repulsion) are expected in ``result.total_energies`` instead.
    Confirm against the qiskit-nature result documentation before comparing
    different molecules by ``groundenergy``.
    """

    # Common solvent dielectric constants
    SOLVENTS = {
        'water': 78.3553,
        'acetonitrile': 35.688,
        'methanol': 32.613,
        'ethanol': 24.852,
        'dmso': 46.826,
        'acetone': 20.493,
        'chloroform': 4.711,
        'benzene': 2.247,
        'toluene': 2.374,
        'hexane': 1.882
    }

    # Recommended basis sets by accuracy/speed trade-off
    BASES = {
        'minimal': 'sto-3g',     # Fastest, least accurate
        'small': '6-31g',        # Good balance for small molecules
        'medium': '6-31g*',      # Includes polarization, recommended
        'large': 'cc-pvdz',      # High accuracy, slower
        'extra_large': 'cc-pvtz' # Very high accuracy, very slow
    }

    def __init__(self, estimator=None, optimizer=None, basis='6-31g*'):
        """
        Args:
            estimator: Qiskit estimator primitive used by VQE. Defaults to
                ``qiskit.primitives.Estimator()``.
            optimizer: Classical optimizer used by VQE. Defaults to
                ``COBYLA(maxiter=200)``.
            basis: Basis set name passed to PySCF (see ``BASES``).
        """
        self.estimator = estimator if estimator is not None else Estimator()
        self.optimizer = optimizer if optimizer is not None else COBYLA(maxiter=200)
        self.basis = basis

    def compute_ground_state_energy(self, molecule_coordinates, in_solvent=False,
                                    symbols=None, solvent='water', active_space=None,
                                    env_coords=None, env_charges=None):
        """
        Calculate ground state energy for a molecule.

        Args:
            molecule_coordinates: Molecule geometry. In vacuum it is passed as
                the ``atom`` argument of ``PySCFDriver``; in solvent it is
                zipped with ``symbols`` (one coordinate triple per atom).
            in_solvent: Whether to include solvent effects
            symbols: Atomic symbols (required for solvent calculations)
            solvent: Solvent name or dielectric constant (default: 'water')
            active_space: Tuple (num_electrons, num_orbitals) for active space.
                If None, it is auto-determined based on molecule size.
            env_coords: Coordinates of external point charges (solvent path only)
            env_charges: Values of the external point charges (solvent path only)

        Returns:
            Ground state calculation result (from ``GroundStateEigensolver.solve``).

        Raises:
            ValueError: ``in_solvent`` is True but ``symbols`` is missing.
        """
        if in_solvent:
            if symbols is None:
                raise ValueError("Atomic symbols are required for solvent calculations")

            return self.compute_gse_in_solvent(
                molecule_coordinates, symbols, solvent=solvent, active_space=active_space,
                env_coords=env_coords, env_charges=env_charges)

        return self.compute_gse_in_vacuum(molecule_coordinates, active_space)

    def _determine_active_space(self, num_electrons, num_orbitals, symbols=None):
        """
        Auto-determine a reasonable active space based on the molecule.

        Args:
            num_electrons: Total number of electrons
            num_orbitals: Total number of spatial orbitals
            symbols: Atomic symbols, used to count atoms and heavy (non-H) atoms

        Returns:
            Tuple (active_electrons, active_orbitals)
        """
        # For small molecules, use more conservative active spaces
        if symbols:
            num_atoms = len(symbols)
            heavy_atoms = sum(1 for s in symbols if s not in ['H'])

            if num_atoms <= 2:  # Diatomic molecules (H2, LiH, etc.)
                return (2, 2)
            elif heavy_atoms <= 2:  # Small molecules (H2O, NH3, CH4, etc.)
                return (min(4, num_electrons), min(4, num_orbitals))
            elif heavy_atoms <= 4:  # Medium molecules
                return (min(6, num_electrons), min(6, num_orbitals))
            else:  # Larger molecules
                return (min(8, num_electrons), min(8, num_orbitals))

        # Fallback based on system size
        if num_electrons <= 4:
            return (2, 2)
        elif num_electrons <= 8:
            return (4, 4)
        elif num_electrons <= 12:
            return (6, 6)
        else:
            return (8, 8)

    def _get_solvent_dielectric_constant(self, solvent):
        """
        Resolve ``solvent`` to a dielectric constant.

        Args:
            solvent: Either a solvent name from ``SOLVENTS`` (case-insensitive)
                or a numeric dielectric constant (must be >= 1).

        Returns:
            The dielectric constant.

        Raises:
            ValueError: Unknown solvent name or non-physical dielectric constant.
            TypeError: ``solvent`` is neither a string nor a real number.
        """
        import numbers

        # Numeric values are accepted directly, as the docstring of
        # compute_ground_state_energy promises. bool is excluded because it is
        # a subclass of int.
        if isinstance(solvent, numbers.Real) and not isinstance(solvent, bool):
            dielectric = float(solvent)
            if not dielectric >= 1.0:  # written this way to also reject NaN
                raise ValueError(f"Dielectric constant must be >= 1, got {solvent}")
            return dielectric

        if not isinstance(solvent, str):
            raise TypeError(
                f"Solvent must be a name or a dielectric constant, got {type(solvent).__name__}")

        solvent_lower = solvent.lower()
        if solvent_lower in self.SOLVENTS:
            return self.SOLVENTS[solvent_lower]

        available = ', '.join(self.SOLVENTS.keys())
        raise ValueError(f"Unknown solvent '{solvent}'. Available: {available}")

    @staticmethod
    def _prepare_env_charges(env_coords, env_charges):
        """
        Validate external point charges and convert them to PySCF-ready arrays.

        Args:
            env_coords: (N, 3) coordinates, or a flat (x, y, z, x, y, z, ...)
                sequence of length 3N.
            env_charges: N charge values (any shape; it is flattened).

        Returns:
            Tuple (coords, charges) of float arrays with shapes (N, 3) and (N,).

        Raises:
            ValueError: Malformed coordinates or mismatched counts.
        """
        # numpy is imported locally: the notebook only imports it in a later cell.
        import numpy as np

        coords = np.array(env_coords, dtype=float)
        charges = np.array(env_charges, dtype=float).reshape(-1)

        if coords.ndim == 1:
            # Flattened input: the count must be validated AFTER reshaping,
            # otherwise valid 3N-long coordinate lists would be rejected.
            if coords.size % 3 != 0:
                raise ValueError(
                    f"Flattened coordinates must have a length divisible by 3, got {coords.size}")
            coords = coords.reshape(-1, 3)
        elif coords.ndim != 2 or coords.shape[1] != 3:
            raise ValueError(f"Coordinates must be (N,3) array, got {coords.shape}")

        if len(coords) != len(charges):
            raise ValueError(f"Coordinate and charge array lengths don't match: "
                             f"{len(coords)} vs {len(charges)}")

        return coords, charges

    def _solve_with_vqe(self, es_problem, active_electrons, active_orbitals):
        """
        Reduce ``es_problem`` to the requested active space and solve it with VQE.

        The reduced problem is mapped with Jordan-Wigner; VQE optimizes a UCCSD
        ansatz whose initial state is Hartree-Fock and whose initial parameters
        come from ``HFInitialPoint``.

        Returns:
            The result of ``GroundStateEigensolver.solve`` on the reduced problem.
        """
        es_problem = ActiveSpaceTransformer(active_electrons, active_orbitals).transform(es_problem)

        mapper = JordanWignerMapper()

        initial_state = HartreeFock(
            es_problem.num_spatial_orbitals,
            es_problem.num_particles,
            mapper,
        )

        ansatz = UCCSD(
            es_problem.num_spatial_orbitals,
            es_problem.num_particles,
            mapper,
            initial_state=initial_state,
        )

        print(f"Number of qubits: {ansatz.num_qubits}")
        print(f"Number of parameters: {ansatz.num_parameters}")

        vqe_solver = VQE(self.estimator, ansatz, self.optimizer)

        initial_point = HFInitialPoint()
        initial_point.ansatz = ansatz
        initial_point.problem = es_problem
        vqe_solver.initial_point = initial_point.to_numpy_array()

        return GroundStateEigensolver(mapper, vqe_solver).solve(es_problem)

    def compute_gse_in_vacuum(self, molecule_coordinates, active_space=None):
        """
        Calculate the ground state energy of a neutral, closed-shell molecule in vacuum.

        Args:
            molecule_coordinates: Geometry passed as the ``atom`` argument of
                ``PySCFDriver`` (in Angstrom).
            active_space: (num_electrons, num_orbitals), or None to auto-determine.

        Returns:
            The result of ``GroundStateEigensolver.solve``.
        """
        try:
            driver = PySCFDriver(
                atom=molecule_coordinates,
                basis=self.basis,
                charge=0,
                spin=0,
                unit=DistanceUnit.ANGSTROM,
            )

            es_problem = driver.run()

            if active_space is None:
                active_space = self._determine_active_space(
                    sum(es_problem.num_particles),
                    es_problem.num_spatial_orbitals
                )
            active_electrons, active_orbitals = active_space

            print(f"Using active space: ({active_electrons}, {active_orbitals})")

            return self._solve_with_vqe(es_problem, active_electrons, active_orbitals)

        except Exception as e:
            print(f"Error while calculating GSE in vacuum: {str(e)}")
            raise

    def compute_gse_in_solvent(self, molecule_coordinates, symbols,
                               solvent='water', active_space=None,
                               env_coords=None, env_charges=None):
        """
        Calculate the ground state energy of a neutral, closed-shell molecule in solvent.

        An RHF calculation is run with the ddCOSMO solvent model (optionally
        with external point charges added via PySCF QM/MM). Its one-body AO
        integrals and MO coefficients are then used to build the qubit problem
        solved by VQE.

        Args:
            molecule_coordinates: One (x, y, z) triple per atom, in Angstrom.
            symbols: Atomic symbols, same order and length as the coordinates.
            solvent: Solvent name from ``SOLVENTS`` or a dielectric constant.
            active_space: (num_electrons, num_orbitals), or None to auto-determine.
            env_coords: Optional point-charge coordinates, (N, 3) or flat 3N.
            env_charges: Optional point-charge values (N). Point charges are
                only added when both env_coords and env_charges are given.

        Returns:
            The result of ``GroundStateEigensolver.solve``.

        Raises:
            ValueError: Unknown solvent, mismatched inputs or an odd electron count.
        """
        try:
            dielectric = self._get_solvent_dielectric_constant(solvent)
            print(f"Using solvent: {solvent} (dielectric constant = {dielectric})")

            coordinates = list(molecule_coordinates)
            if len(symbols) != len(coordinates):
                # zip() below would otherwise silently drop the extra atoms.
                raise ValueError(
                    f"Number of atomic symbols ({len(symbols)}) does not match "
                    f"number of coordinates ({len(coordinates)})")

            mol = gto.Mole()
            mol.atom = [(s, tuple(c)) for s, c in zip(symbols, coordinates)]
            mol.basis = self.basis
            mol.charge = 0
            mol.spin = 0
            mol.unit = 'Angstrom'
            mol.build()

            nelec = mol.nelectron
            # This implementation uses Restricted Hartree-Fock (RHF), which only
            # supports closed-shell systems (even electron count). Open-shell
            # systems (radicals, transition metal complexes) would require UHF
            # and a modified ansatz. The check runs before the (expensive) SCF.
            if nelec % 2 != 0:
                raise ValueError("Odd number of electrons not supported in this implementation")

            mf = scf.RHF(mol)

            # Add pocket charges using QM/MM
            if env_coords is not None and env_charges is not None:
                env_coords_array, env_charges_array = self._prepare_env_charges(
                    env_coords, env_charges)
                print(f"Final shapes - Coords: {env_coords_array.shape}, "
                      f"Charges: {env_charges_array.shape}")
                mf = qmmm.mm_charge(mf, env_coords_array, env_charges_array)
                print(f"Added {len(env_charges_array)} protein pocket charges via QM/MM")

            # Then add the ddCOSMO solvent model on top
            mf = mf.ddCOSMO()
            mf.with_solvent.eps = dielectric

            print("Running SCF calculation with solvent effects...")
            mf.kernel()

            if not mf.converged:
                print("Warning: SCF calculation did not converge")

            # NOTE(review): with QM/MM, get_hcore() appears to include the MM
            # point-charge potential. The ddCOSMO reaction field, however, seems
            # to be added by PySCF in get_veff rather than get_hcore, so the
            # solvent may only enter through mf.mo_coeff below — confirm.
            hcore = mf.get_hcore()

            # Two-electron integrals in the AO basis
            eri = mol.intor('int2e')

            norb = hcore.shape[0]

            if active_space is None:
                active_space = self._determine_active_space(nelec, norb, symbols)
            active_electrons, active_orbitals = active_space

            print(f"Using active space: ({active_electrons}, {active_orbitals})")

            hamiltonian = ElectronicEnergy.from_raw_integrals(
                h1_a=hcore,
                h2_aa=eri,
                h1_b=None,   # Use alpha for beta (RHF)
                h2_bb=None,  # Use alpha for beta (RHF)
                h2_ba=None   # Use alpha for beta (RHF)
            )
            # Record the nuclear repulsion energy (mf.energy_nuc(), which for a
            # QM/MM object presumably also covers nucleus-charge terms) so
            # total energies are comparable with the vacuum path, where
            # PySCFDriver sets it. It is stored as a constant offset.
            hamiltonian.nuclear_repulsion_energy = mf.energy_nuc()

            es_problem = ElectronicStructureProblem(hamiltonian)
            es_problem.num_particles = (nelec // 2, nelec // 2)
            es_problem.num_spatial_orbitals = norb
            es_problem.basis = ElectronicBasis.AO

            # Rotate the AO-basis Hamiltonian into the (solvated) SCF MO basis.
            mo_coeff_integrals = ElectronicIntegrals.from_raw_integrals(mf.mo_coeff)
            basis_transformer = BasisTransformer(
                ElectronicBasis.AO, ElectronicBasis.MO, mo_coeff_integrals)
            es_problem = basis_transformer.transform(es_problem)

            return self._solve_with_vqe(es_problem, active_electrons, active_orbitals)

        except Exception as e:
            print(f"Error while calculating GSE in solvent: {str(e)}")
            raise

    def get_energy_summary(self, result):
        """
        Extract and format energy information from a calculation result.

        Optional fields that are missing or None are reported as None (or
        omitted for 'eigenvalue') instead of aborting the whole summary.

        Returns:
            A dict with the energies, or ``{'error': message}`` on failure.
        """
        try:
            total_energies = getattr(result, 'total_energies', None)
            summary = {
                'ground_state_energy': result.groundenergy,
                'total_energy': (total_energies[0]
                                 if total_energies is not None and len(total_energies) > 0
                                 else None),
                'nuclear_repulsion_energy': getattr(result, 'nuclear_repulsion_energy', None),
                'converged': getattr(result, 'converged', True)
            }

            eigenvalues = getattr(result, 'eigenvalues', None)
            if eigenvalues is not None and len(eigenvalues) > 0:
                summary['eigenvalue'] = eigenvalues[0]

            if hasattr(result, 'aux_operators_evaluated'):
                summary['auxiliary_operators'] = result.aux_operators_evaluated

            return summary

        except Exception as e:
            print(f"Error extracting energy summary: {str(e)}")
            return {'error': str(e)}

Format of the ligand coordinates file (`ligands.json`) produced by the first notebook:
```js
{
    "pocket_1_name": [
        {
            "smiles": "",
            "qed": qed,
            "mw": mw,
            "path": path,
            "coords": coords,
        },
        {
            // ...
        },
    ],

    "pocket_2_name": [
        {
            // ...
        },
    ],

}
```

In [27]:
class LigandCoordinateFileUtil:
    """Read and write the per-pocket ligand JSON files used by this notebook."""

    def load(self, file_path):
        """
        Load ligand records from a JSON file.

        Args:
            file_path: Path of the JSON file (e.g. the ligands coordinates file).

        Returns:
            The decoded JSON object (for ligand files: a dict mapping pocket
            names to lists of ligand records).
        """
        # Explicit encoding: the platform default is not guaranteed to be UTF-8.
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def save(self, smiles_data, file_path):
        """
        Save ligand records to a JSON file, overwriting it if it exists.

        Args:
            smiles_data: JSON-serializable ligand data (e.g. pocket -> ligands).
            file_path: Destination path.
        """
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(smiles_data, f)

In [28]:
class LigandStabilityVetter:
    """Rank generated ligands per pocket by their computed ground state energy (GSE)."""

    def __init__(self, grand_state_calculator):
        """
        Args:
            grand_state_calculator: Object exposing
                ``compute_ground_state_energy(coords, active_space=...)`` whose
                result has a ``groundenergy`` attribute (e.g. a
                GroundStateEnergyCalculator). The parameter keeps its original
                spelling so existing keyword callers keep working.
        """
        self.gse_calculator = grand_state_calculator

    def vet(self, ligands, active_space=None, top_k=3):
        """
        Compute the GSE of every ligand and keep the most stable ones per pocket.

        For each pocket, the ``top_k`` ligands with the lowest GSE are kept and
        then ordered by descending QED.

        Args:
            ligands: Dict mapping pocket names to lists of ligand records with
                keys 'smiles', 'coords', 'qed', 'mw' and 'path'.
            active_space: Forwarded to ``compute_ground_state_energy``.
            top_k: Number of lowest-GSE ligands to keep per pocket (default 3);
                None keeps all of them.

        Returns:
            Dict mapping pocket names to the selected ligand records, each with
            the original fields plus 'gse'.

        Raises:
            ValueError: ``top_k`` is negative.
        """
        if top_k is not None and top_k < 0:
            raise ValueError(f"top_k must be non-negative or None, got {top_k}")

        calculator = self.gse_calculator

        ligands_gse = {}
        for pocket_file_name, pocket_ligands in ligands.items():
            print(f"Processing ligands for Pocket {pocket_file_name}")
            scored = []
            ligands_gse[pocket_file_name] = scored

            for ligand in pocket_ligands:
                print(f"Processing ligand {ligand['smiles']}")
                gse = calculator.compute_ground_state_energy(
                    ligand["coords"], active_space=active_space)

                scored.append({
                    "smiles": ligand["smiles"],
                    "coords": ligand["coords"],
                    "qed": ligand["qed"],
                    "mw": ligand["mw"],
                    "path": ligand["path"],
                    "gse": gse.groundenergy,
                })
            print()

        selected = {}
        for pocket_file_name, scored in ligands_gse.items():
            # Lowest GSE first, keep top_k, then present them by descending QED.
            most_stable = sorted(scored, key=lambda x: x["gse"])[:top_k]
            selected[pocket_file_name] = sorted(
                most_stable, key=lambda x: x["qed"], reverse=True)

        return selected

In [29]:
# # On Colab
# base_dir = "/content/"
# input_dir = os.path.join(base_dir, "input")
# os.makedirs(input_dir, exist_ok=True)
# output_dir = os.path.join(base_dir, "output")
# os.makedirs(output_dir, exist_ok=True)
# toolkit_dir = os.path.join(base_dir, "toolkit")
# os.makedirs(toolkit_dir, exist_ok=True)

# pockets_sub_dir = "pockets"
# ligands_sub_dir = "ligands"

# # Upload the JSON files of the ligands coordinates
# from google.colab import files
# uploaded = files.upload()

# generated_pockets_and_ligands_file = "PfEMP1_pockets_ligands.zip"
# generated_pockets_and_ligands_prefix = os.path.splitext(generated_pockets_and_ligands_file)[0]

In [30]:
# On Kaggle: inputs come from the dataset under /kaggle/input, while generated
# outputs and downloaded tools go under /kaggle/working.
base_dir = "/kaggle/"
input_dir = os.path.join(base_dir, "input", "pfemp1-pockets-ligands")
output_dir = os.path.join(base_dir, "working", "output")
toolkit_dir = os.path.join(base_dir, "working", "toolkit")
os.makedirs(toolkit_dir, exist_ok=True)

# Sub-directory names inside the generated pockets/ligands folder.
pockets_sub_dir, ligands_sub_dir = "pockets", "ligands"

generated_pockets_and_ligands_prefix = "PfEMP1_pockets_ligands"

In [31]:
# Derive the pocket/ligand folders and the ligand coordinates JSON path.
generated_pockets_and_ligands_dir = os.path.join(input_dir, generated_pockets_and_ligands_prefix)
generated_pockets_dir, generated_ligands_dir = (
    os.path.join(generated_pockets_and_ligands_dir, sub_dir)
    for sub_dir in (pockets_sub_dir, ligands_sub_dir)
)

ligand_coordinates_file = os.path.join(generated_ligands_dir, "ligands.json")

In [32]:
# # On Colab Unzip to input_dir
# import zipfile

# with zipfile.ZipFile(generated_pockets_and_ligands_file, 'r') as zip_ref:
#     zip_ref.extractall(input_dir)

In [33]:
# Load the per-pocket ligand records generated by the first notebook
# (format described in the markdown above).
ligands_coords_file_util = LigandCoordinateFileUtil()
all_ligands = ligands_coords_file_util.load(ligand_coordinates_file)

In [34]:
# Show every loaded ligand of each pocket with its QED and molecular weight.
for pocket_file_name, pocket_ligands in all_ligands.items():
    header = f"Selected first 5 ligands for Pocket {pocket_file_name}:"
    print(header)
    for ligand in pocket_ligands:
        row = f"{ligand['smiles']} | QED: {ligand['qed']:.2f} | MW: {ligand['mw']:.2f}"
        print(row)
    print()

Selected first 5 ligands for Pocket pocket1_atm.pdb:
O=C(O)CCc1cccc(CO)c1 | QED: 0.73 | MW: 180.20
COc1ccc2c(O)cccc2c1 | QED: 0.72 | MW: 174.20
COC1CCc2c(O)cccc2C1 | QED: 0.71 | MW: 178.23
O=C(O)COc1ccccc1 | QED: 0.71 | MW: 152.15
O=C(O)c1cccc(C(=O)O)c1 | QED: 0.69 | MW: 166.13

Selected first 5 ligands for Pocket pocket2_atm.pdb:
O=C(O)CCc1ccccc1 | QED: 0.71 | MW: 150.18
COc1cccc(C(=O)O)c1 | QED: 0.70 | MW: 152.15
CC(=O)NCc1ccccc1 | QED: 0.67 | MW: 149.19
O=C(O)C=Cc1ccccc1 | QED: 0.65 | MW: 148.16
O=C(O)c1ccccc1 | QED: 0.61 | MW: 122.12

Selected first 5 ligands for Pocket pocket3_atm.pdb:
COc1cccc(O)c1OC | QED: 0.70 | MW: 154.16
Oc1cccc(-c2ccccc2)c1 | QED: 0.70 | MW: 170.21
Oc1ccc(-c2ccccc2)cc1 | QED: 0.70 | MW: 170.21
NC(=O)c1ccc(C(=O)O)cc1 | QED: 0.67 | MW: 165.15
COc1ccc(C)cc1OC | QED: 0.65 | MW: 152.19

Selected first 5 ligands for Pocket pocket4_atm.pdb:
CNc1ccccc1 | QED: 0.58 | MW: 107.16
NCc1ccccc1 | QED: 0.57 | MW: 107.16
COc1ccccc1 | QED: 0.53 | MW: 108.14
Cc1cncnc1N | QED: 

In [35]:
# Use the 'minimal' basis (sto-3g: fastest, least accurate per BASES) to keep
# each ground state energy calculation short.
basis = GroundStateEnergyCalculator.BASES['minimal']
grand_state_calculator = GroundStateEnergyCalculator(basis=basis)

  self.estimator = estimator if estimator else Estimator()


In [36]:
# Vetter that ranks each pocket's ligands by the GSE computed with the calculator above.
ligand_vetting_module = LigandStabilityVetter(grand_state_calculator)

In [37]:
# The active space is fixed to (2, 2) to keep each VQE problem as small (and
# fast) as possible, which definitely costs accuracy in the computed GSE.
# Set `active_space = None` to let the calculator choose a molecule-dependent
# active space instead (see `_determine_active_space` in
# `GroundStateEnergyCalculator`); each ligand will then take longer.
active_space = (2, 2)
selected_ligands = {}

selected_ligands_dir = os.path.join(output_dir, generated_pockets_and_ligands_prefix)
os.makedirs(selected_ligands_dir, exist_ok=True)

selected_ligands_after_vetting_1_path = os.path.join(
    selected_ligands_dir, "selected_ligands_after_vetting_1.json"
)

# Set to True to re-run the vetting even when a saved result already exists.
force_selection = False
if os.path.isfile(selected_ligands_after_vetting_1_path) and not force_selection:
    # Reuse the selection saved by an earlier run.
    print(f"Ligands have already been selected in vetting round 1.\nLoading already selected at: {selected_ligands_after_vetting_1_path}")
    selected_ligands = ligands_coords_file_util.load(selected_ligands_after_vetting_1_path)
else:
    selected_ligands = ligand_vetting_module.vet(all_ligands, active_space=active_space)
    ligands_coords_file_util.save(selected_ligands, selected_ligands_after_vetting_1_path)

Ligands have already been selected in vetting round 1.
Loading already selected at: /kaggle/working/output/PfEMP1_pockets_ligands/selected_ligands_after_vetting_1.json


In [38]:
# Show the ligands kept for each pocket together with their computed GSE.
for pocket_file_name, pocket_ligands in selected_ligands.items():
    header = f"Selected top 3 ligands for Pocket {pocket_file_name}:"
    print(header)
    for ligand in pocket_ligands:
        row = f"{ligand['smiles']} | GSE: {ligand['gse']}"
        print(row)
    print()

Selected top 3 ligands for Pocket pocket1_atm.pdb:
O=C(O)CCc1cccc(CO)c1 | GSE: -0.8710423057125096
O=C(O)COc1ccccc1 | GSE: -0.8200483888728851
O=C(O)c1cccc(C(=O)O)c1 | GSE: -0.8403112242835951

Selected top 3 ligands for Pocket pocket2_atm.pdb:
O=C(O)CCc1ccccc1 | GSE: -0.8750926619590413
COc1cccc(C(=O)O)c1 | GSE: -0.8148779340464601
O=C(O)c1ccccc1 | GSE: -0.8631357554851568

Selected top 3 ligands for Pocket pocket3_atm.pdb:
COc1cccc(O)c1OC | GSE: -0.749759730236824
NC(=O)c1ccc(C(=O)O)cc1 | GSE: -0.8097061348704405
COc1ccc(C)cc1OC | GSE: -0.7302808794323312

Selected top 3 ligands for Pocket pocket4_atm.pdb:
NCc1ccccc1 | GSE: -0.8599308297497598
Cc1cncnc1N | GSE: -0.8436204242760501
Cn1ccc(C=N)c1 | GSE: -0.8201674570880849



## 4. Selecting the Most Stable Binding Ligand

In [39]:
!pip install rdkit Bio --quiet

### Dock ligands into the protein to produce complex

In [40]:
!pip install openbabel-wheel vina --quiet

In [41]:
# !find /kaggle/working/toolkit/mgltools -name prepare_ligand4.py
# !find /kaggle/working/toolkit/mgltools -name prepare_receptor4.py

In [42]:
import os
import shutil
import subprocess
import tarfile
import urllib.request
import numpy as np
from vina import Vina
from rdkit import Chem
from rdkit.Chem import AllChem
from openbabel import openbabel as ob, pybel



class VinaDocking:
    """
    Dock ligands into a protein receptor with AutoDock Vina.

    Receptor and ligand PDBQT files are prepared either with MGLTools
    (``use_mgl_tools=True``; downloaded and installed into ``toolkit_dir``) or
    with the OpenBabel Python bindings (default). Docking artefacts are written
    below ``work_dir``, in a sub-directory named after the directory that holds
    the ligand file being docked (see ``resolve_ligand_pocket_dir``).
    """

    # AutoDock atom types that are not plain element symbols. "A" is aromatic
    # carbon; the others are hydrogen-bonding variants of H, N, O and S.
    _AUTODOCK_TYPE_TO_ELEMENT = {
        "A": "C",
        "HD": "H",
        "HS": "H",
        "NA": "N",
        "NS": "N",
        "OA": "O",
        "OS": "O",
        "SA": "S",
    }

    def __init__(self, work_dir, toolkit_dir, force_setup=False, use_mgl_tools=False):
        """
        Args:
            work_dir: Root directory for prepared and docked files.
            toolkit_dir: Directory into which MGLTools is downloaded/extracted.
            force_setup: Re-download and reinstall MGLTools even if already present.
            use_mgl_tools: Prepare receptor/ligands with MGLTools instead of OpenBabel.
        """
        self.work_dir = work_dir
        self.toolkit_dir = toolkit_dir
        self.force_setup = force_setup
        self.use_mgl_tools = use_mgl_tools
        self.vina = Vina(sf_name='vina')

        # Filled in by download_and_setup_mgltools() when MGLTools is used.
        self.mgltools_pythonsh = None
        self.mgltools_utilities_dir = None

        if self.use_mgl_tools:
            self.download_and_setup_mgltools()

    def download_and_setup_mgltools(self):
        """
        Download, extract and install MGLTools (unless already present) and record
        the paths of its ``pythonsh`` interpreter and ``Utilities24`` scripts.

        Returns:
            Dict with the ``pythonsh`` and ``utilities_dir`` paths.

        Raises:
            RuntimeError: If the extracted archive is empty or ``pythonsh`` is missing.
        """
        extract_dir = os.path.join(self.toolkit_dir, "mgltools")
        fresh_install = self.force_setup or not os.path.isdir(extract_dir)

        if fresh_install:
            mgltools_url = "https://ccsb.scripps.edu/mgltools/download/491/"
            archive_name = os.path.join(self.toolkit_dir, "mgltools.tar.gz")
            os.makedirs(self.toolkit_dir, exist_ok=True)

            print("Downloading MGLTools...")
            urllib.request.urlretrieve(mgltools_url, archive_name)

            print("Extracting MGLTools...")
            os.makedirs(extract_dir, exist_ok=True)
            with tarfile.open(archive_name, "r:gz") as tar:
                tar.extractall(path=extract_dir)

        # The archive unpacks into a single versioned directory, which can only be
        # located after extraction (the original looked it up before downloading).
        entries = sorted(os.listdir(extract_dir))
        if not entries:
            raise RuntimeError(f"No MGLTools distribution found in {extract_dir}.")
        install_dir = os.path.join(extract_dir, entries[0])
        installed_dir = os.path.join(install_dir, "installed")

        if fresh_install:
            install_script = os.path.join(install_dir, "install.sh")
            print("Running install.sh...")
            subprocess.run(["bash", install_script, "-d", installed_dir], cwd=install_dir, check=True)
            print("MGLTools installed.")
        else:
            print("MGLTools already installed.\nTo force reinstallation, consider instantiating this class with the `force_setup` argument set to True.")

        pythonsh = os.path.join(installed_dir, "bin", "pythonsh")
        utilities_dir = os.path.join(installed_dir, "MGLToolsPckgs", "AutoDockTools", "Utilities24")

        if not os.path.exists(pythonsh):
            raise RuntimeError("Could not find pythonsh. Check installation.")

        os.chmod(pythonsh, 0o755)

        self.mgltools_pythonsh = pythonsh
        self.mgltools_utilities_dir = utilities_dir

        print("MGLTools setup complete.")
        return {
            "pythonsh": pythonsh,
            "utilities_dir": utilities_dir
        }

    def preprocess_protein_with_openbabel(self, input_pdb, output_pdb):
        """Add hydrogens and Gasteiger charges to a protein PDB, keeping its coordinates."""
        mol = next(pybel.readfile("pdb", input_pdb))
        mol.addh()
        # Deliberately no make3D(): it generates new coordinates from scratch,
        # which would discard the receptor's binding-site geometry.
        mol.calccharges(model='gasteiger')
        mol.write("pdb", output_pdb, overwrite=True)

    def prepare_receptor(self, input_pdb, output_pdbqt=None):
        """Convert the receptor PDB to PDBQT with the configured toolkit; return the PDBQT path."""
        if self.use_mgl_tools:
            return self.prepare_receptor_mgl(input_pdb, output_pdbqt)
        return self.prepare_receptor_openbabel(input_pdb, output_pdbqt)

    def prepare_receptor_mgl(self, input_pdb, output_pdbqt=None):
        """
        Prepare the receptor with MGLTools' ``prepare_receptor4.py``.

        ``input_pdb`` is overwritten in place with the OpenBabel-preprocessed
        structure (hydrogens added) before conversion.
        """
        self.preprocess_protein_with_openbabel(input_pdb, input_pdb)

        if output_pdbqt is None:
            output_pdbqt_name = self.change_extension_to(os.path.basename(input_pdb), "pdbqt")
            output_pdbqt = os.path.join(self.work_dir, output_pdbqt_name)

        script = os.path.join(self.mgltools_utilities_dir, "prepare_receptor4.py")
        # -C: keep the input charges instead of adding new ones.
        cmd = [self.mgltools_pythonsh, script, "-r", input_pdb, "-o", output_pdbqt, "-A", "hydrogens", "-C", "c"]
        print(f"Running: {' '.join(cmd)}")
        subprocess.run(cmd, check=True)
        print(f"Receptor prepared: {output_pdbqt}")

        return output_pdbqt

    def prepare_receptor_openbabel(self, input_pdb, output_pdbqt=None):
        """Convert the receptor PDB to a rigid PDBQT with OpenBabel, adding hydrogens."""
        if output_pdbqt is None:
            output_pdbqt_name = self.change_extension_to(os.path.basename(input_pdb), "pdbqt")
            output_pdbqt = os.path.join(self.work_dir, output_pdbqt_name)

        conv = ob.OBConversion()
        conv.SetInAndOutFormats("pdb", "pdbqt")

        # "r": write a rigid molecule (no branch/torsion tree), as a Vina receptor needs.
        conv.AddOption("r", conv.OUTOPTIONS)
        # NOTE(review): "x" is not a documented pdbqt output option in OpenBabel,
        # so this presumably has no effect — confirm before relying on it.
        conv.AddOption("x", conv.OUTOPTIONS)

        mol = ob.OBMol()
        conv.ReadFile(mol, input_pdb)
        mol.AddHydrogens()
        conv.WriteFile(mol, output_pdbqt)

        print(f"Receptor prepared: {output_pdbqt}")
        return output_pdbqt

    def _convert_protein_to_pdbqt(self, pdb_file, pdbqt_file):
        """
        Convert a protein PDB to a minimal rigid-receptor PDBQT (no ROOT/BRANCH tags).

        Hydrogens are added with OpenBabel first. Every atom then gets a zero
        partial charge and its element symbol as the AutoDock type.
        """
        temp_pdb_h = os.path.join(self.work_dir, "temp_protein_h.pdb")
        print("Adding hydrogens")
        self._add_hydrogens_to_protein(pdb_file, temp_pdb_h)

        print("Converting to PDBQT")
        try:
            with open(temp_pdb_h, 'r') as f_in, open(pdbqt_file, 'w') as f_out:
                for line in f_in:
                    if line.startswith(('ATOM', 'HETATM')):
                        atom_name = line[12:16].strip()
                        element = line[76:78].strip() or atom_name[:1]
                        # PDBQT keeps PDB columns 1-54 (record .. coordinates), then
                        # occupancy, B-factor, partial charge and the atom type.
                        record = line.rstrip('\r\n')[:54]
                        f_out.write(f"{record:<54}  0.00  0.00    +0.000 {element}\n")
                    elif line.startswith(('REMARK', 'HEADER', 'TITLE')):
                        f_out.write(line)
        finally:
            if os.path.exists(temp_pdb_h):
                os.remove(temp_pdb_h)

    def _add_hydrogens_to_protein(self, input_pdb, output_pdb):
        """Add hydrogens to a protein PDB file using OpenBabel."""
        conv = ob.OBConversion()
        conv.SetInAndOutFormats("pdb", "pdb")

        mol = ob.OBMol()
        conv.ReadFile(mol, input_pdb)
        mol.AddHydrogens()

        conv.WriteFile(mol, output_pdb)

    def prepare_ligand(self, input_pdb, output_pdbqt=None):
        """Convert a ligand PDB to PDBQT with the configured toolkit; return the PDBQT path."""
        if self.use_mgl_tools:
            return self.prepare_ligand_mgl(input_pdb, output_pdbqt)
        return self.prepare_ligand_openbabel(input_pdb, output_pdbqt)

    def prepare_ligand_mgl(self, input_pdb, output_pdbqt=None):
        """Convert a ligand PDB to PDBQT using MGLTools' ``prepare_ligand4.py``."""
        if output_pdbqt is None:
            output_pdbqt_name = self.change_extension_to(os.path.basename(input_pdb), "pdbqt")
            output_pdbqt = os.path.join(self.work_dir, output_pdbqt_name)

        script = os.path.join(self.mgltools_utilities_dir, "prepare_ligand4.py")
        cmd = [self.mgltools_pythonsh, script, "-l", input_pdb, "-o", output_pdbqt, "-A", "hydrogens"]
        print(f"Running: {' '.join(cmd)}")
        subprocess.run(cmd, check=True)
        print(f"Ligand prepared: {output_pdbqt}")

        return output_pdbqt

    def prepare_ligand_openbabel(self, input_pdb, output_pdbqt=None):
        """Convert a ligand PDB to PDBQT using OpenBabel."""
        if output_pdbqt is None:
            output_pdbqt_name = self.change_extension_to(os.path.basename(input_pdb), "pdbqt")
            output_pdbqt = os.path.join(self.work_dir, output_pdbqt_name)

        return self.convert_to_pdbqt(input_pdb, output_pdbqt)

    def convert_to_pdbqt(self, original_file, output_file=None):
        """
        Convert a molecule file (format taken from its extension) to PDBQT with OpenBabel.

        Args:
            original_file: Input structure file.
            output_file: Destination; defaults to ``original_file`` with a ``.pdbqt`` extension.

        Returns:
            The path written.
        """
        original_file_extension = os.path.splitext(original_file)[1][1:].lower()
        new_file = output_file if output_file is not None else f"{os.path.splitext(original_file)[0]}.pdbqt"

        conv = ob.OBConversion()
        conv.SetInAndOutFormats(original_file_extension, "pdbqt")

        # "r": rigid output (no branch/torsion tree), so Vina will not sample ligand
        # torsions. NOTE(review): confirm a rigid ligand is intended.
        conv.AddOption("r", ob.OBConversion.OUTOPTIONS)
        # NOTE(review): the CLI flag "-xr" corresponds to option key "r"; the key
        # "xr" itself is presumably ignored by OpenBabel.
        conv.AddOption("xr", ob.OBConversion.OUTOPTIONS)

        mol = ob.OBMol()
        conv.ReadFile(mol, original_file)
        mol.AddHydrogens()
        conv.WriteFile(mol, new_file)

        return new_file

    def convert_sdf_to_pdb(self, input_sdf, output_pdb=None):
        """Convert the first molecule of an SDF file to PDB, adding hydrogens and 3D coordinates."""
        if output_pdb is None:
            output_pdb_name = self.change_extension_to(os.path.basename(input_sdf), "pdb")
            output_pdb = os.path.join(self.work_dir, output_pdb_name)

        mol = next(pybel.readfile("sdf", input_sdf))
        mol.addh()
        mol.make3D()
        mol.write("pdb", output_pdb, overwrite=True)

        return output_pdb

    def change_extension_to(self, file_name, extension):
        """Return ``file_name`` with its extension replaced by ``extension``."""
        return f"{os.path.splitext(file_name)[0]}.{extension}"

    def dock_ligand(self, protein_pdbqt, ligand_pdbqt, center, size=(20, 20, 20),
                    exhaustiveness=32, n_poses=10):
        """
        Dock a ligand into the receptor inside the given search box.

        Args:
            protein_pdbqt: Receptor PDBQT path.
            ligand_pdbqt: Ligand PDBQT path.
            center: Box center (x, y, z).
            size: Box edge lengths (x, y, z).
            exhaustiveness: Vina search exhaustiveness.
            n_poses: Number of poses generated, written and scored.

        Returns:
            Dict with the written poses file (``output_file``), the top pose's
            score (``best_affinity``, first column of ``energies``) and the full
            ``energies`` table returned by Vina.
        """
        self.vina.set_receptor(protein_pdbqt)
        self.vina.set_ligand_from_file(ligand_pdbqt)

        self.vina.compute_vina_maps(center=list(center), box_size=list(size))
        self.vina.dock(exhaustiveness=exhaustiveness, n_poses=n_poses)

        docked_ligand_pocket_dir = self.resolve_ligand_pocket_dir(ligand_pdbqt)
        ligand_file_name = self.resolve_ligand_file_name(ligand_pdbqt)
        output_file = os.path.join(docked_ligand_pocket_dir, f"{ligand_file_name}_docked.pdbqt")
        self.vina.write_poses(output_file, n_poses=n_poses, overwrite=True)

        energies = self.vina.energies(n_poses=n_poses)

        return {
            'output_file': output_file,
            'best_affinity': energies[0][0],
            'energies': energies
        }

    def extract_best_pose(self, docking_output):
        """Write the first MODEL of a docking output file to ``*_best_pose.pdbqt``; return its path."""
        input_file = docking_output['output_file']
        output_file = f"{input_file.split('_docked.pdbqt')[0]}_best_pose.pdbqt"

        with open(input_file, 'r') as f:
            lines = f.readlines()

        pose_lines = []
        in_first_pose = False
        for line in lines:
            # Compare tokens so that e.g. "MODEL 10" can never match "MODEL 1".
            if line.split()[:2] == ["MODEL", "1"]:
                in_first_pose = True
                pose_lines.append(line)
            elif in_first_pose:
                pose_lines.append(line)
                if line.startswith("ENDMDL"):
                    break

        with open(output_file, 'w') as f:
            f.writelines(pose_lines)

        return output_file

    def extract_pocket_region(self, protein_atoms, ligand_pose_file, cutoff=10.0):
        """
        Select protein atoms near a docked ligand and write pocket/ligand/complex PDBs.

        Args:
            protein_atoms: Atom dicts as returned by ``read_protein_atoms``.
            ligand_pose_file: PDBQT file holding a single ligand pose.
            cutoff: Keep protein atoms whose nearest ligand atom is within this
                distance (coordinate units, i.e. Å for PDB files).

        Returns:
            Dict with ``pocket_pdb``, ``ligand_pdb``, ``complex_pdb`` paths and
            ``n_pocket_atoms``.

        Raises:
            ValueError: If the pose file contains no ATOM/HETATM records.
        """
        ligand_coords = self._get_coordinates(ligand_pose_file)
        if ligand_coords.size == 0:
            raise ValueError(f"No ligand atoms found in {ligand_pose_file}")

        if protein_atoms:
            protein_coords = np.array([[atom['x'], atom['y'], atom['z']] for atom in protein_atoms])
            # (n_protein, n_ligand) distance matrix reduced to each protein atom's
            # distance to its nearest ligand atom.
            nearest = np.linalg.norm(
                protein_coords[:, np.newaxis, :] - ligand_coords[np.newaxis, :, :], axis=2
            ).min(axis=1)
            pocket_atoms = [atom for atom, dist in zip(protein_atoms, nearest) if dist <= cutoff]
        else:
            pocket_atoms = []

        pocket_dir = self.resolve_ligand_pocket_dir(ligand_pose_file)
        ligand_name = self.resolve_ligand_file_name(ligand_pose_file)

        pocket_pdb = os.path.join(pocket_dir, f"{ligand_name}_pocket.pdb")
        ligand_pdb = os.path.join(pocket_dir, f"{ligand_name}_ligand.pdb")
        complex_pdb = os.path.join(pocket_dir, f"{ligand_name}_complex.pdb")

        self._write_pocket_pdb(pocket_atoms, pocket_pdb)
        self._write_ligand_pdb(ligand_pose_file, ligand_pdb)
        self._write_complex_pdb(pocket_atoms, ligand_pose_file, complex_pdb)

        return {
            'pocket_pdb': pocket_pdb,
            'ligand_pdb': ligand_pdb,
            'complex_pdb': complex_pdb,
            'n_pocket_atoms': len(pocket_atoms)
        }

    def resolve_ligand_pocket_dir(self, ligand_path):
        """Return (and create) ``work_dir/<name of the ligand file's parent directory>``."""
        pocket_name = os.path.basename(os.path.dirname(ligand_path))
        pocket_dir = os.path.join(self.work_dir, pocket_name)
        os.makedirs(pocket_dir, exist_ok=True)
        return pocket_dir

    def resolve_ligand_file_name(self, ligand_file_path):
        """Return the ligand file's base name without extension."""
        return os.path.splitext(os.path.basename(ligand_file_path))[0]

    def prepare_vqe_data(self, pocket_files, ligand_pose_file):
        """Collect coordinates and element symbols of the ligand, pocket and complex."""
        ligand_coords = self._get_coordinates(ligand_pose_file)
        ligand_symbols = self._get_symbols(ligand_pose_file)

        pocket_coords, pocket_symbols = self._get_coords_from_pdb(pocket_files['pocket_pdb'])
        complex_coords, complex_symbols = self._get_coords_from_pdb(pocket_files['complex_pdb'])

        return {
            'ligand': {
                'coordinates': ligand_coords,
                'symbols': ligand_symbols
            },
            'pocket': {
                'coordinates': pocket_coords,
                'symbols': pocket_symbols
            },
            'complex': {
                'coordinates': complex_coords,
                'symbols': complex_symbols
            }
        }

    def _get_coordinates(self, pose_file):
        """Return the x/y/z columns of all ATOM/HETATM records as an (n, 3) array."""
        coordinates = []
        with open(pose_file, 'r') as f:
            for line in f:
                if line.startswith(("ATOM", "HETATM")):
                    coordinates.append([float(line[30:38]), float(line[38:46]), float(line[46:54])])
        return np.array(coordinates)

    def _element_from_pdbqt_line(self, line):
        """
        Return the element symbol of a PDBQT atom record.

        The AutoDock atom type (columns 78-79) is mapped to its element, e.g.
        "A" -> "C", "OA" -> "O", "Cl" -> "Cl". If the type is blank, the first
        character of the atom name is used.
        """
        autodock_type = line[77:79].strip()
        if not autodock_type:
            return line[12:16].strip()[:1]
        return self._AUTODOCK_TYPE_TO_ELEMENT.get(autodock_type.upper(), autodock_type.capitalize())

    def _get_symbols(self, pose_file):
        """Return the element symbols of all ATOM/HETATM records in a PDBQT file."""
        with open(pose_file, 'r') as f:
            return [
                self._element_from_pdbqt_line(line)
                for line in f
                if line.startswith(("ATOM", "HETATM"))
            ]

    def read_protein_atoms(self, protein_file):
        """Read ATOM/HETATM records of a PDB file as dicts with the stripped line and x/y/z."""
        atoms = []
        with open(protein_file, 'r') as f:
            for line in f:
                if line.startswith(('ATOM', 'HETATM')):
                    try:
                        atoms.append({
                            'line': line.strip(),
                            'x': float(line[30:38]),
                            'y': float(line[38:46]),
                            'z': float(line[46:54])
                        })
                    except (ValueError, IndexError):
                        # Skip records with malformed coordinate columns.
                        continue
        return atoms

    def _get_coords_from_pdb(self, pdb_file):
        """Return an (n, 3) coordinate array and element symbols from a PDB file."""
        coordinates = []
        symbols = []

        with open(pdb_file, 'r') as f:
            for line in f:
                if line.startswith(('ATOM', 'HETATM')):
                    coordinates.append([float(line[30:38]), float(line[38:46]), float(line[46:54])])
                    element = line[76:78].strip() or line[12:16].strip()[:1]
                    symbols.append(element)

        return np.array(coordinates), symbols

    def _pdbqt_atom_to_pdb_line(self, line, serial):
        """
        Rebuild a PDBQT ATOM/HETATM record as a fixed-column PDB HETATM record
        for residue ``LIG A 1``, keeping atom name and coordinates and writing the
        element (derived from the AutoDock type) in columns 77-78.
        """
        atom_name = line[12:16]
        coords = line[30:54]
        element = self._element_from_pdbqt_line(line)
        # Columns: 1-6 record, 7-11 serial, 13-16 name, 18-20 resName, 22 chain,
        # 23-26 resSeq, 31-54 xyz, 55-60 occupancy, 61-66 B-factor, 77-78 element.
        return (
            f"HETATM{serial:5d} {atom_name:<4} LIG A   1    "
            f"{coords}{1.0:6.2f}{20.0:6.2f}{'':10}{element:>2}\n"
        )

    def _write_pocket_pdb(self, pocket_atoms, output_file):
        """Write pocket atoms to a PDB file, renumbering serials from 1."""
        with open(output_file, 'w') as f:
            for serial, atom in enumerate(pocket_atoms, 1):
                line = atom['line']
                f.write(f"{line[:6]}{serial:5d}{line[11:]}\n")
            f.write("END\n")

    def _write_ligand_pdb(self, ligand_pdbqt_file, output_file):
        """Convert the ligand atoms of a PDBQT file to a PDB file."""
        with open(ligand_pdbqt_file, 'r') as f_in, open(output_file, 'w') as f_out:
            serial = 1
            for line in f_in:
                if line.startswith(('ATOM', 'HETATM')):
                    f_out.write(self._pdbqt_atom_to_pdb_line(line, serial))
                    serial += 1
            f_out.write("END\n")

    def _write_complex_pdb(self, pocket_atoms, ligand_pdbqt_file, output_file):
        """Write pocket atoms followed by the ligand atoms to one PDB file."""
        with open(output_file, 'w') as f:
            serial = 1

            for atom in pocket_atoms:
                line = atom['line']
                f.write(f"{line[:6]}{serial:5d}{line[11:]}\n")
                serial += 1

            with open(ligand_pdbqt_file, 'r') as lig_f:
                for line in lig_f:
                    if line.startswith(('ATOM', 'HETATM')):
                        f.write(self._pdbqt_atom_to_pdb_line(line, serial))
                        serial += 1

            f.write("END\n")

    def get_binding_site_from_fpocket(self, fpocket_pdb, padding=5.0, min_size=15.0):
        """
        Derive a Vina search box from the atoms of an fpocket pocket file.

        Args:
            fpocket_pdb: Pocket PDB file.
            padding: Added to the atoms' extent along each axis.
            min_size: Lower bound for each box edge.

        Returns:
            ``(center, size)`` as tuples of three values.

        Raises:
            ValueError: If the file contains no ATOM/HETATM records.
        """
        coords = self._get_coordinates(fpocket_pdb)
        if coords.size == 0:
            raise ValueError(f"No ATOM/HETATM records found in {fpocket_pdb}")

        center = coords.mean(axis=0)
        size = np.maximum(coords.max(axis=0) - coords.min(axis=0) + padding, min_size)

        return tuple(center), tuple(size)

In [43]:
# All docking artefacts are written below this directory.
docked_ligands_dir = f"{selected_ligands_dir}/docked_ligands"
os.makedirs(docked_ligands_dir, exist_ok=True)

# Protein atoms within this distance of the docked ligand form the pocket region.
cutoff = 10.0

# PfEMP1 receptor structure used for docking.
protein_pdb = os.path.join(generated_pockets_and_ligands_dir, "PfEMP1.pdb")

docking = VinaDocking(
    work_dir=docked_ligands_dir,
    toolkit_dir=toolkit_dir,
    force_setup=True,
    use_mgl_tools=False,
)

In [None]:
# Work on a copy of the receptor so preprocessing never touches the original file.
preprocessed_protein_pdb = os.path.join(docked_ligands_dir, os.path.basename(protein_pdb))
shutil.copy(protein_pdb, preprocessed_protein_pdb)

print(f"Preparing receptor {preprocessed_protein_pdb}:")
protein_pdbqt = docking.prepare_receptor(preprocessed_protein_pdb)

# The protein does not change between ligands, so parse its atoms once.
protein_atoms = docking.read_protein_atoms(protein_pdb)

# pocket file name -> ligand SMILES -> docking/pocket-extraction results
docking_results = {}

for pocket_file_name, pocket_ligands in selected_ligands.items():
    print(f"Docking ligands for pockets {pocket_file_name}:")
    pocket_prefix = os.path.splitext(pocket_file_name)[0]

    docking_results[pocket_file_name] = {}
    pocket_path = os.path.join(generated_pockets_dir, pocket_file_name)
    ligands_dir = os.path.join(generated_ligands_dir, pocket_prefix)
    docked_ligands_output_dir = os.path.join(docked_ligands_dir, pocket_prefix)
    os.makedirs(docked_ligands_output_dir, exist_ok=True)

    # Search box derived from the fpocket pocket atoms.
    print(f"Getting binding site from pockets {pocket_path}:")
    center, size = docking.get_binding_site_from_fpocket(pocket_path)

    for ligand in pocket_ligands:
        ligand_smiles = ligand['smiles']
        ligand_base_name = os.path.basename(ligand['path'])
        ligand_sdf_path = os.path.join(ligands_dir, ligand_base_name)
        ligand_stem = os.path.splitext(ligand_base_name)[0]

        print(f"Docking ligand {ligand_smiles} from {ligand_sdf_path}")

        # Keep the intermediate files in this pocket's directory. The docking helpers
        # name their output sub-directory after the ligand file's parent directory,
        # so this also stops equally named ligands of different pockets overwriting
        # each other.
        ligand_pdb = docking.convert_sdf_to_pdb(
            ligand_sdf_path, os.path.join(docked_ligands_output_dir, f"{ligand_stem}.pdb"))
        ligand_pdbqt = docking.prepare_ligand(
            ligand_pdb, os.path.join(docked_ligands_output_dir, f"{ligand_stem}.pdbqt"))

        # Separate name: re-binding `docking_results` here would discard the
        # results dictionary being built.
        vina_result = docking.dock_ligand(protein_pdbqt, ligand_pdbqt, center, size)
        best_pose_file = docking.extract_best_pose(vina_result)
        pocket_files = docking.extract_pocket_region(protein_atoms, best_pose_file, cutoff)
        vqe_data = docking.prepare_vqe_data(pocket_files, best_pose_file)

        result = {
            'binding_affinity': vina_result['best_affinity'],
            'pocket_files': pocket_files,
            'vqe_data': vqe_data,
            'work_dir': docked_ligands_output_dir,
        }
        docking_results[pocket_file_name][ligand_smiles] = result

        print(f"Binding affinity: {result['binding_affinity']:.2f} kcal/mol")
        print(f"Pocket atoms: {result['pocket_files']['n_pocket_atoms']}")
        print(f"Files in: {result['work_dir']}")
        print()
    print()

In [None]:
import os
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from Bio.PDB.PDBParser import PDBParser
from Bio.PDB.NeighborSearch import NeighborSearch


class LigandProteinBindingEnergyCalculator:
    """
    Estimate a ligand's binding energy in a protein pocket.

    The binding energy is the difference between two ground-state energies of
    the ligand in solvent, computed by ``ground_state_energy_calculator``:
    one with the pocket's point charges and one without them. The ligand
    geometry is generated from its SMILES with RDKit (ETKDG embedding, then
    UFF optimisation). Pocket atoms within ``distance_cutoff`` of any ligand
    atom are represented by element-based point charges.

    NOTE(review): the RDKit conformer is built in its own coordinate frame,
    not in the frame of the pocket PDB. The neighbour search in
    ``_extract_pocket_charges`` therefore only finds the intended atoms if the
    two frames happen to overlap. Confirm, or feed in the docked pose instead.
    """

    def __init__(self, pocket_pdb_file, smiles, ground_state_energy_calculator, distance_cutoff=6.0, solvent='water', active_space=None):
        """
        Args:
            pocket_pdb_file: PDB file of the pocket (or protein) providing point charges.
            smiles: Ligand SMILES.
            ground_state_energy_calculator: Object exposing ``compute_ground_state_energy``.
            distance_cutoff: Neighbour-search radius around each ligand atom.
            solvent: Solvent name forwarded to the energy calculator.
            active_space: Active-space setting forwarded to the energy calculator.

        Raises:
            ValueError: If ``smiles`` cannot be parsed.
            RuntimeError: If RDKit cannot embed a 3D conformer.
        """
        self.pdb_file = pocket_pdb_file
        self.smiles = smiles
        self.distance_cutoff = distance_cutoff

        self.ligand_mol = self._build_ligand(smiles)

        self.ligand_coords = self._get_rdkit_coords(self.ligand_mol)
        self.ligand_symbols = [atom.GetSymbol() for atom in self.ligand_mol.GetAtoms()]
        self.ligand_charges = [float(atom.GetProp('_GasteigerCharge')) for atom in self.ligand_mol.GetAtoms()]

        self.pocket_coords, self.pocket_charges = self._extract_pocket_charges()

        self.gse_calculator = ground_state_energy_calculator
        self.solvent = solvent
        self.active_space = active_space

    @staticmethod
    def _build_ligand(smiles):
        """Return an embedded, UFF-optimised, Gasteiger-charged RDKit molecule with explicit Hs."""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"Invalid SMILES: {smiles!r}")
        mol = Chem.AddHs(mol)
        # EmbedMolecule returns the conformer id, or -1 when embedding fails.
        if AllChem.EmbedMolecule(mol, AllChem.ETKDG()) < 0:
            raise RuntimeError(f"RDKit could not embed a 3D conformer for {smiles!r}")
        AllChem.UFFOptimizeMolecule(mol)
        AllChem.ComputeGasteigerCharges(mol)
        return mol

    def _get_rdkit_coords(self, mol):
        """Return the coordinates of the molecule's first conformer as an (n, 3) array."""
        conf = mol.GetConformer()
        return np.array([[conf.GetAtomPosition(i).x,
                          conf.GetAtomPosition(i).y,
                          conf.GetAtomPosition(i).z] for i in range(mol.GetNumAtoms())])

    def _extract_pocket_charges(self):
        """Return coordinates and estimated charges of PDB atoms near any ligand atom."""
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure("protein", self.pdb_file)
        ns = NeighborSearch(list(structure.get_atoms()))

        nearby = set()
        for x, y, z in self.ligand_coords:
            nearby.update(ns.search((x, y, z), self.distance_cutoff))

        # Set iteration order is not guaranteed to be stable between runs; sort
        # so the point-charge list (and the energies computed from it) is reproducible.
        ordered = sorted(nearby, key=lambda atom: atom.get_full_id())
        coords = [atom.coord for atom in ordered]
        charges = [self._estimate_charge_from_element(atom.element) for atom in ordered]

        return np.array(coords), np.array(charges)

    def _estimate_charge_from_element(self, element):
        """Crude per-element partial charge; 0.0 for elements not listed."""
        mmff_like = {
            'O': -0.55,
            'N': -0.45,
            'S': -0.25,
            'H': 0.25,
            'C': 0.15
        }
        return mmff_like.get(element, 0.0)

    def compute_ligand_energy_in_solvent(self):
        """Compute the ground-state energy of the isolated ligand in solvent."""
        return self.gse_calculator.compute_ground_state_energy(
            self.ligand_coords,
            True,
            self.ligand_symbols,
            self.solvent,
            self.active_space,
        )

    def compute_complex_energy_in_solvent(self):
        """Compute the ligand's ground-state energy with the pocket point charges, in solvent."""
        return self.gse_calculator.compute_ground_state_energy(
            self.ligand_coords,
            True,
            self.ligand_symbols,
            self.solvent,
            self.active_space,
            self.pocket_coords,
            self.pocket_charges,
        )

    def compute_binding_energy(self, ligand_energy=None):
        """
        Return ``{"e_lig", "e_complex", "e_bind"}`` with ``e_bind = e_complex - e_lig``.

        Args:
            ligand_energy: Optional precomputed ligand-in-solvent result (an object
                with ``groundenergy``). When given, it is reused instead of recomputed.
        """
        print("Calculating energy ligand-in-pocket in solvent")
        E_complex = self.compute_complex_energy_in_solvent()
        print()

        print("Calculating ligand energy in solvent")
        # Explicit None check: a supplied result must never be recomputed just
        # because the object happens to evaluate as falsy.
        E_lig = ligand_energy if ligand_energy is not None else self.compute_ligand_energy_in_solvent()
        print()

        return {
            "e_lig": E_lig.groundenergy,
            "e_complex": E_complex.groundenergy,
            "e_bind": E_complex.groundenergy - E_lig.groundenergy,
        }

In [None]:
class BindingEnergiesVetter:
    """Score each pocket's candidate ligands by solvated binding energy.

    Every ligand is evaluated with ``LigandProteinBindingEnergyCalculator``;
    the lowest binding energy is kept per pocket and across all pockets.
    """

    def __init__(self, grand_state_calculator, ligands, solvent='water', active_space=None, pockets_dir=""):
        # Ground-state-energy calculator forwarded to each binding-energy
        # calculator (attribute name kept as-is for API compatibility).
        self.grand_state_calculator = grand_state_calculator
        # {pocket_name: [ligand_record, ...]}; vet() reads the "smiles", "qed",
        # "mw", "path" and "gse" keys of each record.
        self.ligands = ligands
        self.solvent = solvent
        self.active_space = active_space
        # Directory joined with each pocket name to locate the pocket file.
        self.pockets_dir = pockets_dir
        self.binding_energies = {}
        self.most_stable_ligand_overall = {}
        self.most_stable_ligand_per_pocket = {}


    def vet(self):
        """Compute binding energies for all ligands and select the most stable.

        Returns:
            dict with:
              - "binding_energies": {pocket_name: [record, ...]} for every ligand;
              - "most_stable_ligand_per_pocket": {pocket_name: record} with the
                lowest binding energy ({} for a pocket without ligands);
              - "most_stable_ligand_overall": (pocket_name, record) pair with the
                lowest binding energy over all pockets, or {} if nothing was scored.
        """
        for pocket_name, pocket_ligands in self.ligands.items():
            print(f"Processing ligands for Pocket {pocket_name}")
            pocket_records = []
            self.binding_energies[pocket_name] = pocket_records
            self.most_stable_ligand_per_pocket[pocket_name] = {}

            pocket_path = os.path.join(self.pockets_dir, pocket_name)

            for ligand in pocket_ligands:

                binding_energy_result = self.compute_binding_energy(
                    pocket_path,
                    ligand['smiles'],
                    # ligand['gse'],
                )

                pocket_records.append({
                    "smile": ligand['smiles'],
                    "pocket": pocket_name,
                    "qed": ligand['qed'],
                    "mw": ligand['mw'],
                    "path": ligand['path'],
                    "gse": ligand['gse'],
                    "gse_in_solvent": binding_energy_result['e_lig'],
                    "gse_ligand_in_pocket": binding_energy_result['e_complex'],
                    "binding_energy": binding_energy_result['e_bind'],
                })

                # The calculator reports the binding energy under "e_bind".
                print(f"Binding energy for {ligand['smiles']} in {pocket_name}: {binding_energy_result['e_bind']}")

            if pocket_records:
                # Keep the single lowest-binding-energy ligand of this pocket
                # (min() returns the first of equal minima, like a stable sort).
                self.most_stable_ligand_per_pocket[pocket_name] = min(
                    pocket_records, key=lambda record: record["binding_energy"])
            else:
                print(f"No ligands to vet for Pocket {pocket_name}")
            print()

        # Rank only pockets that produced at least one scored ligand.
        scored_pockets = [
            (name, record)
            for name, record in self.most_stable_ligand_per_pocket.items()
            if record
        ]
        self.most_stable_ligand_overall = (
            min(scored_pockets, key=lambda item: item[1]["binding_energy"])
            if scored_pockets
            else {}
        )

        return {
            "most_stable_ligand_overall": self.most_stable_ligand_overall,
            "most_stable_ligand_per_pocket": self.most_stable_ligand_per_pocket,
            "binding_energies": self.binding_energies,
        }


    def compute_binding_energy(self, pocket_file_name, smiles, ligand_energy=None):
        """Compute the binding energy of one ligand (SMILES) in one pocket file.

        Returns the dict produced by
        ``LigandProteinBindingEnergyCalculator.compute_binding_energy``.
        """
        binding_energy_calculator = LigandProteinBindingEnergyCalculator(
            pocket_pdb_file=pocket_file_name,
            smiles=smiles,
            ground_state_energy_calculator=self.grand_state_calculator,
            solvent=self.solvent,
            active_space=self.active_space,
        )

        return binding_energy_calculator.compute_binding_energy(
            ligand_energy=ligand_energy)


In [None]:

selected_ligands_after_vetting_2_path = os.path.join(selected_ligands_dir, "selected_ligands_after_vetting_2.json")

# Run vetting round 2 once; reuse the saved results on later runs.
if not os.path.isfile(selected_ligands_after_vetting_2_path):
    binding_energies_vetting_module = BindingEnergiesVetter(
        grand_state_calculator=grand_state_calculator,
        ligands=selected_ligands,
        pockets_dir=generated_pockets_dir,
        active_space=active_space,
    )

    vetting_results = binding_energies_vetting_module.vet()
    ligands_coords_file_util.save(vetting_results, selected_ligands_after_vetting_2_path)
else:
    print(f"Ligands have already been selected in vetting round 2.\nLoading already selected at: {selected_ligands_after_vetting_2_path}")
    vetting_results = ligands_coords_file_util.load(selected_ligands_after_vetting_2_path)


# BindingEnergiesVetter.vet() stores the overall winner as a
# (pocket_name, ligand_record) pair. After a save/load round-trip it may come
# back as a 2-element list (NOTE(review): assumes the file util writes JSON).
most_stable_entry = vetting_results['most_stable_ligand_overall']
if isinstance(most_stable_entry, (list, tuple)) and len(most_stable_entry) == 2:
    most_stable_pocket_name, most_stable_ligand = most_stable_entry
else:
    most_stable_ligand = most_stable_entry

# Print the most stable ligand with its pocket, energies, QED and MW.
# Record keys are those written by BindingEnergiesVetter.vet() ("smile", not "smiles").
if not most_stable_ligand:
    print("No ligand was scored in vetting round 2.")
else:
    print("Most stable ligand overall:")
    print(f"SMILES: {most_stable_ligand['smile']}")
    print(f"QED: {most_stable_ligand['qed']}")
    print(f"MW: {most_stable_ligand['mw']}")
    print(f"GSE: {most_stable_ligand['gse']}")
    print(f"GSE in solvent: {most_stable_ligand['gse_in_solvent']}")
    print(f"GSE Ligand in pocket: {most_stable_ligand['gse_ligand_in_pocket']}")
    print(f"Binding Energy: {most_stable_ligand['binding_energy']}")
    print(f"Pocket: {most_stable_ligand['pocket']}")


### Custom VQE implementation

In [None]:
# import numpy as np
# from qiskit.algorithms.optimizers import SPSA
# from qiskit_nature.algorithms import MinimumEigensolverResult

In [None]:
# from qiskit_aer import AerSimulator
# from qiskit_nature.second_q.drivers import PySCFDriver
# from qiskit_nature.second_q.transformers import ActiveSpaceTransformer
# from qiskit_nature.second_q.mappers import ParityMapper
# from qiskit_nature.circuit.library import HartreeFock, UCCSD, EfficientSU2
# from qiskit_nature.constants import DistanceUnit
# from qiskit.primitives import EstimatorV2
# from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager

In [None]:
# # This is a custom VQE implementation and Ground State Energy Calculator
# # based on the paper at: https://pubs.acs.org/doi/full/10.1021/acs.jctc.4c01657
# class CustomVQE:
#     def __init__(self, estimator, ansatz, observable, problem, maxiter=250):
#         self.estimator = estimator
#         self.ansatz = ansatz
#         self.observable = observable
#         self.problem = problem
#         self.maxiter = maxiter
#         self.results = []

#     def solve(self):
#         initial_params = np.random.rand(self.ansatz.num_parameters)

#         optimizer = SPSA(
#             maxiter=self.maxiter,
#             callback=self.optimizer_callback
#         )

#         learning, pert = optimizer.calibrate(self.cost_function, initial_params)
#         optimizer.learning_rate = learning
#         optimizer.perturbation = pert

#         result = optimizer.minimize(self.cost_function, initial_params)
#         final_energy = result.fun
#         return final_energy, result.x, self.results

#     def cost_function(self, params):
#         job = self.estimator.run([(self.ansatz, self.observable, params)])
#         exp_val = job.result()[0].data.evs
#         return self.interpret_exp_val(exp_val)

#     def interpret_exp_val(self, exp_val):
#         sol = MinimumEigensolverResult()
#         sol.eigenvalue = np.real(exp_val)
#         return self.problem.interpret_exp_val(sol).total_energies[0]

#     def optimizer_callback(self, ne, params, value, step, accepted):
#         self.results.append(value)
#         print(f"Iteration {len(self.results):03d} - Energy = {value:.6f}")

In [None]:
# class GroundStateEnergyCalculator:
#     def __init__(self, backend=None, hardware_efficient=False, reps=1, entanglement='linear', optimizer_iter=250):
#         self.backend = backend
#         self.hardware_efficient = hardware_efficient
#         self.entanglement = entanglement
#         self.reps = reps
#         self.optimizer_iter = optimizer_iter

#     def calculate_ground_state_energy(self, molecule_coordinates):
#         driver = PySCFDriver(
#             atom=molecule_coordinates,
#             basis="sto3g",
#             charge=0,
#             spin=0,
#             unit=DistanceUnit.ANGSTROM,
#         )
#         es_problem = driver.run()

#         transformer = ActiveSpaceTransformer(num_electrons=2, num_spatial_orbitals=3)
#         reduced_problem = transformer.transform(es_problem)
#         second_q_hamiltonian = reduced_problem.second_q_ops()[0]

#         mapper = ParityMapper(num_particles=reduced_problem.num_particles)
#         qubit_op = mapper.map(second_q_hamiltonian)

#         hf = HartreeFock(
#             num_particles=reduced_problem.num_particles,
#             num_spatial_orbitals=reduced_problem.num_spatial_orbitals,
#             qubit_mapper=mapper
#         )

#         ansatz = UCCSD(
#             reduced_problem.num_spatial_orbitals,
#             reduced_problem.num_particles,
#             initial_state=hf,
#             qubit_mapper=mapper
#         ) if not self.hardware_efficient else EfficientSU2(
#             num_qubits=qubit_op.num_qubits,
#             entanglement=self.entanglement,
#             reps=self.reps
#         )

#         backend = self.backend or AerSimulator()
#         pass_manager = generate_preset_pass_manager(backend=backend, optimization_level=0)
#         isa_ansatz = pass_manager.run(ansatz)
#         isa_observable = qubit_op.apply_layout(isa_ansatz.layout)

#         estimator = EstimatorV2(mode=backend)
#         vqe = CustomVQE(estimator, isa_ansatz, isa_observable, reduced_problem, maxiter=self.optimizer_iter)

#         final_energy, optimal_params, energy_trace = vqe.solve()
#         return final_energy

In [None]:
# !pip freeze > requirements.txt