# Stable species conformer search
Leverage ETKDG for stochastic conformer generation

Use this as a base for ML conformer generation

The idea is to have a modular, swappable method for each step rather than hard-coding them. The steps and candidate methods are:
- initial conformer embedding (ETKDG, GeoMol)
- optimization/energy (MMFF, UFF, GFN-FF, GFN2-xTB)
- pruning (torsion fingerprints, CREGEN)
- convergence metrics (conformational entropy/partition function)

In [37]:
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdmc import OpenBabelFF, RDKitFF, optimize_mol
from rdmc.mol import RDKitMol
from rdmc.view import mol_viewer, interactive_conformer_viewer
import numpy as np
import copy


T = 298  # temperature, K
R = 0.0019872043  # molar gas constant, kcal/(K*mol) (8.314462618 J/(K*mol) / 4184)
HARTREE_TO_KCAL_MOL = 627.5094740631  # 1 Hartree in kcal/mol (CODATA 2018)

In [38]:
import os.path as osp
import yaml
import torch
from geomol.model import GeoMol
from geomol.featurization import featurize_mol_from_smiles
from torch_geometric.data import Batch
from geomol.inference import construct_conformers

class GeoMolEmbedder:
    """Embed conformers with a pretrained GeoMol model.

    Calling an instance with a SMILES string returns an ``RDKitMol`` holding
    ``n_conformers`` conformers whose coordinates are GeoMol predictions.
    """

    def __init__(self, trained_model_dir):
        """Load GeoMol hyperparameters and weights.

        Parameters
        ----------
        trained_model_dir : str
            directory containing ``model_parameters.yml`` and ``best_model.pt``;
            weights are mapped onto the CPU and the model is put in eval mode
        """
        # TODO: add option of pre-pruning geometries using alpha values
        # TODO: add option of changing "temperature" each iteration to sample diverse geometries

        with open(osp.join(trained_model_dir, "model_parameters.yml")) as f:
            model_parameters = yaml.full_load(f)
        model = GeoMol(**model_parameters)

        state_dict = torch.load(osp.join(trained_model_dir, "best_model.pt"), map_location=torch.device('cpu'))
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        self.model = model
        # Featurized graph, cached per SMILES (see __call__)
        self.tg_data = None
        self._tg_smiles = None

    def __call__(self, smiles, n_conformers, std=1.0):
        """Generate ``n_conformers`` GeoMol conformers for ``smiles``.

        Parameters
        ----------
        smiles : str
            molecule to embed
        n_conformers : int
            number of conformers to sample
        std : float, optional
            value assigned to the model's ``random_vec_std`` ("temperature")

        Returns
        -------
        RDKitMol
            molecule with one conformer per GeoMol sample
        """
        # set "temperature"
        self.model.random_vec_std = std

        # Featurization only depends on the SMILES, so cache it -- but re-featurize
        # when a different SMILES is requested instead of reusing a stale graph.
        if self.tg_data is None or self._tg_smiles != smiles:
            self.tg_data = featurize_mol_from_smiles(smiles, dataset="drugs")
            self._tg_smiles = smiles
        data = Batch.from_data_list([self.tg_data])  # needed by GeoMol's internal processing

        # Pure inference: no autograd graph is needed
        with torch.no_grad():
            self.model(data, inference=True, n_model_confs=n_conformers)
            model_coords = construct_conformers(data, self.model)
        model_coords = model_coords.double().cpu().detach().numpy()
        # Conformers are split along axis 1 -- presumably (n_atoms, n_conformers, 3); TODO confirm
        split_model_coords = np.split(model_coords, n_conformers, axis=1)

        # package in mol and return
        mol = RDKitMol.FromSmiles(smiles)
        mol.EmbedMultipleNullConfs(n=n_conformers)
        for i, coords in enumerate(split_model_coords):
            mol.SetPositions(coords=coords.squeeze(axis=1), id=i)

        return mol
    
    
class ETKDGEmbedder:
    """Embed conformers with RDKit's ETKDG through ``RDKitMol.EmbedMultipleConfs``."""

    def __init__(self):
        # Template molecule built from the most recently requested SMILES
        self.mol = None
        self._smiles = None

    def __call__(self, smiles, n_conformers):
        """Return a fresh copy of the template molecule with ``n_conformers`` embedded.

        Parameters
        ----------
        smiles : str
            molecule to embed
        n_conformers : int
            number of conformers to embed

        Returns
        -------
        RDKitMol
            a copy of the cached template, so the template itself stays conformer-free
        """
        # Rebuild the template whenever the SMILES changes; previously the first
        # molecule was silently reused for every later SMILES.
        if self.mol is None or self._smiles != smiles:
            self.mol = RDKitMol.FromSmiles(smiles)
            self._smiles = smiles

        mol = self.mol.Copy()
        mol.EmbedMultipleConfs(n_conformers)
        return mol
    
    
class RandomEmbedder:
    """Create conformers with ``RDKitMol.EmbedMultipleNullConfs``.

    NOTE(review): whether these are random or all-zero coordinates depends on
    rdmc's default for ``EmbedMultipleNullConfs`` -- confirm against rdmc.
    """

    def __init__(self):
        # Template molecule built from the most recently requested SMILES
        self.mol = None
        self._smiles = None

    def __call__(self, smiles, n_conformers):
        """Return a fresh copy of the template molecule with ``n_conformers`` null conformers.

        Parameters
        ----------
        smiles : str
            molecule to embed
        n_conformers : int
            number of conformers to create

        Returns
        -------
        RDKitMol
            a copy of the cached template
        """
        # Rebuild the template whenever the SMILES changes instead of reusing a stale one
        if self.mol is None or self._smiles != smiles:
            self.mol = RDKitMol.FromSmiles(smiles)
            self._smiles = smiles

        mol = self.mol.Copy()
        mol.EmbedMultipleNullConfs(n_conformers)
        return mol

In [39]:
class MMFFOptimizer:
    """Optimize every conformer of a molecule with MMFF and collect the energies."""

    def __init__(self, method="rdkit"):
        """Select the force-field backend.

        Parameters
        ----------
        method : str, optional
            ``"rdkit"`` (uses ``RDKitFF``); ``"openbabel"`` is not implemented yet

        Raises
        ------
        NotImplementedError
            for ``"openbabel"``
        ValueError
            for any other unrecognized backend (previously this left ``self.ff``
            unset and failed later with an AttributeError)
        """
        if method == "rdkit":
            self.ff = RDKitFF()
        elif method == "openbabel":
            raise NotImplementedError
        else:
            raise ValueError(f"Unsupported MMFF backend: {method}")

    def __call__(self, mol):
        """Optimize all conformers of ``mol`` (a copy is used; ``mol`` is untouched).

        Returns
        -------
        list of dict
            one ``{"positions", "conf", "energy"}`` entry per optimized conformer,
            energies in kcal/mol; empty if the optimizer returned no results
        """
        self.ff.setup(mol.Copy())
        results = self.ff.optimize_confs()
        # Each result is a pair whose second element is the energy (kcal/mol).
        # A comprehension (unlike zip(*results)) also handles an empty result list.
        energies = [energy for _, energy in results]
        opt_mol = self.ff.get_optimized_mol()

        current_mol_data = []
        for c_id, energy in zip(range(opt_mol.GetNumConformers()), energies):
            conf = copy.copy(opt_mol.GetConformer(c_id))
            current_mol_data.append({"positions": conf.GetPositions(),
                                     "conf": conf,
                                     "energy": energy})
        return current_mol_data
    
    
class XTBOptimizer:
    """Optimize each conformer independently with xTB through ``run_xtb_calc``."""

    def __init__(self, method="gff"):
        # xTB method name without leading dashes (e.g. "gff", "gfn2")
        self.method = method

    def __call__(self, mol):
        """Optimize every conformer of ``mol`` one at a time.

        Returns
        -------
        list of dict
            ``{"positions", "conf", "energy"}`` per successfully optimized conformer,
            energy in kcal/mol; conformers for which ``run_xtb_calc`` raises
            ``ValueError`` are reported with ``print`` and skipped
        """
        # Scratch molecule with exactly one conformer; each input geometry is copied into it
        work_mol = mol.Copy()
        work_mol._mol.RemoveAllConformers()
        work_mol.EmbedNullConformer()

        results = []
        for conformer in mol.GetAllConformers():
            work_mol.SetPositions(conformer.GetPositions())
            try:
                _, optimized = run_xtb_calc(work_mol, opt=True, return_optmol=True, method=self.method)
            except ValueError as err:
                print(err)
                continue

            opt_conf = optimized.Copy().GetConformer(0)
            opt_positions = opt_conf.GetPositions()
            # The property is labelled in Hartree; convert to kcal/mol (TODO: check)
            energy_kcal = float(optimized.GetProp('total energy / Eh')) * HARTREE_TO_KCAL_MOL
            results.append({"positions": opt_positions,
                            "conf": opt_conf,
                            "energy": energy_kcal})

        return results

In [40]:
class SCGMetric:
    """Track a Boltzmann-weighted convergence metric across search iterations.

    Energies are expected in kcal/mol (they are divided by ``R * T`` with ``R``
    in kcal/(K*mol)). Supported metrics:

    - ``"entropy"``: conformational entropy ``-R * sum(p * ln p)``
    - ``"partition function"``: sum of Boltzmann factors relative to the
      lowest-energy conformer (so the ground state contributes exactly 1)

    Convergence: the relative spread ``(max - min) / |min|`` of the last
    ``window`` recorded values is at most ``threshold``.
    """

    def __init__(self, metric="entropy", window=5, threshold=0.01, T=298):
        self.metric = metric
        self.window = window
        self.threshold = threshold
        self.T = T
        self.metric_history = []

    def calculate_metric(self, mol_data):
        """Compute the configured metric for ``mol_data``, record it and return it.

        Raises
        ------
        NotImplementedError
            if ``self.metric`` is not a supported metric name
        """
        if self.metric == "entropy":
            metric_val = self.calculate_entropy(mol_data)
        elif self.metric == "partition function":
            metric_val = self.calculate_partition_function(mol_data)
        else:
            raise NotImplementedError(f"Metric {self.metric} is not supported.")

        self.metric_history.append(metric_val)
        return metric_val

    def check_metric(self):
        """Return True when the last ``window`` metric values have converged."""
        recent = np.asarray(self.metric_history[-self.window:], dtype=float)
        if recent.size == 0:
            return False

        min_metric = recent.min()
        max_metric = recent.max()
        if min_metric == 0:
            # The relative change is undefined at zero (e.g. a single conformer has
            # zero entropy); treat the series as converged only if it is constant.
            return bool(max_metric == 0)
        change = (max_metric - min_metric) / abs(min_metric)
        return bool(change <= self.threshold)

    def _boltzmann_factors(self, mol_data):
        """Boltzmann factors exp(-dE / RT) relative to the lowest energy."""
        energies = np.array([c["energy"] for c in mol_data], dtype=float)
        energies = energies - energies.min()
        return np.exp(-energies / (R * self.T))

    def calculate_entropy(self, mol_data):
        """Conformational entropy -R * sum(p ln p) of the Boltzmann populations."""
        factors = self._boltzmann_factors(mol_data)
        prob = factors / factors.sum()
        # Very high-energy conformers underflow to p == 0; their limit p*ln(p) -> 0,
        # but evaluating 0 * log(0) directly would produce nan.
        prob = prob[prob > 0]
        return -R * np.sum(prob * np.log(prob))

    def calculate_partition_function(self, mol_data):
        """Conformational partition function relative to the lowest-energy conformer."""
        # The lowest-energy conformer already contributes exp(0) = 1, so no extra
        # "+ 1" term (which double-counted the ground state).
        return self._boltzmann_factors(mol_data).sum()

In [51]:
def rad_angle_compare(x, y):
    """Absolute difference between angles ``x`` and ``y`` (radians), in degrees within [0, 180]."""
    delta = x - y
    # atan2(sin, cos) wraps the raw difference into (-pi, pi]
    return np.abs(np.arctan2(np.sin(delta), np.cos(delta))) * 180 / np.pi


def torsion_list_compare(c1_ts, c2_ts):
    """Pairwise ``rad_angle_compare`` over two torsion sequences; stops at the shorter one."""
    return [rad_angle_compare(a, b) for a, b in zip(c1_ts, c2_ts)]

class TorsionPruner:
    """
    Prune conformers based on torsion angle criteria.
    This method uses a mean and max criteria to prune conformers:
    A conformer is considered unique if it satisfies either of the following criteria:
        mean difference of all torsion angles > mean_chk_threshold
        max difference of all torsion angles > max_chk_threshold
    i.e. a new conformer is a duplicate of a kept conformer only when BOTH its
    mean and its max torsion difference to that conformer are below the thresholds.
    New conformers are compared to all conformers that have already been deemed unique
    (previously kept ones plus new ones accepted earlier in the same call).
    Molecules without torsional modes are rigid: only one conformer is kept.
    """

    def __init__(self, mean_chk_threshold=10, max_chk_threshold=20):
        # thresholds in degrees
        self.mean_chk_threshold = mean_chk_threshold
        self.max_chk_threshold = max_chk_threshold
        self.torsions_list = None

    def initialize_torsions_list(self, smiles):
        """Store the torsional modes (atom-index tuples) of ``smiles``."""
        mol = RDKitMol.FromSmiles(smiles)
        mol.EmbedNullConformer()
        self.torsions_list = mol.GetConformer().GetTorsionalModes()

    def calculate_torsions(self, mol_data):
        """Attach a ``"torsions"`` array (degrees in [0, 360)) to every entry; returns ``mol_data``."""
        for conf_data in mol_data:
            conf = conf_data["conf"]
            torsions = np.array([conf.GetTorsionDeg(t) for t in self.torsions_list], dtype=float) % 360
            conf_data.update({"torsions": torsions})
        return mol_data

    @staticmethod
    def _angular_diff_deg(a_rad, b_rad):
        """Absolute wrapped angle difference in degrees; broadcasts over arrays."""
        delta = a_rad - b_rad
        return np.abs(np.arctan2(np.sin(delta), np.cos(delta))) * 180 / np.pi

    def __call__(self, current_mol_data, unique_mol_data):
        """Merge new conformers into the unique set.

        Parameters
        ----------
        current_mol_data : list of dict
            newly optimized conformers (``"conf"`` and ``"energy"`` keys)
        unique_mol_data : list of dict
            conformers kept so far (already carrying ``"torsions"``)

        Returns
        -------
        list of dict
            unique conformers sorted by energy
        """
        # calculate torsions for new mols
        current_mol_data = self.calculate_torsions(current_mol_data)

        mols_list = unique_mol_data + current_mol_data
        if not mols_list:
            return []

        # With no prior unique mols, the first new mol is accepted unconditionally
        n_unique_mols = max(1, len(unique_mol_data))
        torsion_matrix_rad = np.deg2rad(np.stack([c["torsions"] for c in mols_list]))
        has_torsions = torsion_matrix_rad.shape[1] > 0
        kept_ids = list(range(n_unique_mols))

        for i in range(n_unique_mols, len(mols_list)):
            if not has_torsions:
                # Rigid molecule: nothing distinguishes conformers, so it's a duplicate
                continue

            # (n_kept, n_torsions) differences against conformers kept so far only
            comp = self._angular_diff_deg(torsion_matrix_rad[i], torsion_matrix_rad[kept_ids])
            close_mean = comp.mean(axis=1) < self.mean_chk_threshold
            close_max = comp.max(axis=1) < self.max_chk_threshold
            # duplicate if some kept conformer is close by BOTH criteria
            if not (close_mean & close_max).any():
                kept_ids.append(i)

        # update mols and sort by energy
        return sorted([mols_list[i] for i in kept_ids], key=lambda x: x["energy"])
    
    
class CRESTPruner:
    """Deduplicate conformers with CREST's CREGEN (delegates to ``run_cre_check``)."""

    def __init__(self):
        pass  # no configuration or state

    def __call__(self, current_mol_data, unique_mol_data):
        """Merge ``current_mol_data`` into ``unique_mol_data`` and return the survivors sorted by energy."""
        print("current:", len(unique_mol_data), "new:", len(current_mol_data))
        candidates = unique_mol_data + current_mol_data
        return sorted(run_cre_check(candidates), key=lambda entry: entry["energy"])

In [52]:
class StochasticConformerGenerator:
    """Iteratively embed, optimize and prune conformers until a metric converges.

    Each iteration embeds ``n_conformers_per_iter`` guesses, optionally optimizes
    them, merges them into the unique set with ``pruner`` and records ``metric``.
    After ``min_iters`` iterations the search stops as soon as
    ``metric.check_metric()`` is True, and in any case after ``max_iters``.
    """

    def __init__(self, smiles, conformer_embedder, optimizer, pruner,
                 metric, min_iters=5, max_iters=10, optimize=True):
        super().__init__()

        self.smiles = smiles
        self.conformer_embedder = conformer_embedder
        self.optimizer = optimizer
        self.pruner = pruner
        self.metric = metric

        self.mol = RDKitMol.FromSmiles(smiles)
        self.unique_mol_data = []
        self.iter = 0  # total iterations run, across calls
        self.min_iters = min_iters
        self.max_iters = max_iters
        self.optimize = optimize

        if isinstance(self.pruner, TorsionPruner):
            self.pruner.initialize_torsions_list(smiles)

    def calculate_energy(self, mol, unique_mols):
        """Fill NaN energies with an RDKit force-field single point.

        Every entry must carry ``"conf_id"``, the index of its conformer in ``mol``.
        """
        # TODO: figure out what to do w this (only used if optimize is set to False)
        ff = RDKitFF()
        for c in unique_mols:
            if np.isnan(c["energy"]):
                ff.setup(mol.Copy(), conf_id=c["conf_id"])
                energy = ff.get_energy()
                # presumably kcal/mol (RDKit force-field units); the old note said kJ -- TODO confirm
                c.update({"energy": energy})

        return unique_mols

    def _unoptimized_mol_data(self, mol):
        """Package ``mol``'s conformers as mol-data entries with force-field energies."""
        mol_data = []
        for c_id in range(mol.GetNumConformers()):
            conf = copy.copy(mol.GetConformer(c_id))
            mol_data.append({"positions": conf.GetPositions(),
                             "conf": conf,
                             "energy": np.nan,
                             # needed by calculate_energy to pick the right conformer
                             "conf_id": c_id})
        return self.calculate_energy(mol, mol_data)

    def __call__(self, n_conformers_per_iter):
        """Run the search and return the unique conformers found (sorted by the pruner).

        Returns the current unique set even if every iteration failed.
        """
        print(f"Generating conformers for {self.smiles}")
        for it in range(self.max_iters):
            self.iter += 1

            print(f"\nIteration {self.iter}: embedding {n_conformers_per_iter} initial guesses...")
            initial_mol = self.conformer_embedder(self.smiles, n_conformers_per_iter)

            if self.optimize:
                print(f"Iteration {self.iter}: optimizing initial guesses...")
                opt_mol_data = self.optimizer(initial_mol)
            else:
                opt_mol_data = self._unoptimized_mol_data(initial_mol)

            # check for failures
            if len(opt_mol_data) == 0:
                print("Failed to optimize any of the embedded conformers")
                continue

            print(f"Iteration {self.iter}: pruning conformers...")
            unique_mol_data = self.pruner(opt_mol_data, self.unique_mol_data)
            self.metric.calculate_metric(unique_mol_data)
            self.unique_mol_data = unique_mol_data

            if it < self.min_iters:
                continue

            if self.metric.check_metric():
                print(f"Iteration {self.iter}: stop criteria reached")
                return self.unique_mol_data

        print(f"Iteration {self.iter}: max iterations reached")
        # self.unique_mol_data is always bound, unlike the loop-local variable
        return self.unique_mol_data

In [66]:
# Pick one implementation per pipeline step (alternatives left commented out).

# Embedding: RandomEmbedder() | ETKDGEmbedder() | GeoMolEmbedder(<trained model dir>)
conformer_embedder = GeoMolEmbedder("../../GeoMol/trained_models/drugs/")

# Optimization: MMFFOptimizer() | XTBOptimizer(method="gff")
optimizer = XTBOptimizer(method="gff")

# Pruning: TorsionPruner(mean_chk_threshold=10, max_chk_threshold=20) | CRESTPruner()
pruner = CRESTPruner()

# Convergence: relative spread of the entropy over the last 5 iterations <= 1%
metric = SCGMetric(metric="entropy", window=5, threshold=0.01)

# Alternative test molecule: O=C(/C=C/c1ccco1)Nc1ccc(Cl)c(S(=O)(=O)N2CCOCC2)c1
n_conformers_per_iter = 50
scg = StochasticConformerGenerator(
    smiles="c1ccccc1",
    conformer_embedder=conformer_embedder,
    optimizer=optimizer,
    pruner=pruner,
    metric=metric,
    max_iters=100,
    optimize=True,
)
unique_conformers = scg(n_conformers_per_iter)
print(len(unique_conformers), scg.metric.metric_history[-1])

Generating conformers for c1ccccc1

Iteration 1: embedding 50 initial guesses...
Iteration 1: optimizing initial guesses...
Iteration 1: pruning conformers...
current: 0 new: 50
Removed 46 duplicate conformers with cregen

Iteration 2: embedding 50 initial guesses...
Iteration 2: optimizing initial guesses...
Iteration 2: pruning conformers...
current: 4 new: 50
Removed 49 duplicate conformers with cregen

Iteration 3: embedding 50 initial guesses...
Iteration 3: optimizing initial guesses...
Iteration 3: pruning conformers...
current: 5 new: 50
Removed 46 duplicate conformers with cregen

Iteration 4: embedding 50 initial guesses...
Iteration 4: optimizing initial guesses...
Iteration 4: pruning conformers...
current: 9 new: 50
Removed 53 duplicate conformers with cregen

Iteration 5: embedding 50 initial guesses...
Iteration 5: optimizing initial guesses...
Iteration 5: pruning conformers...
current: 6 new: 50
Removed 49 duplicate conformers with cregen

Iteration 6: embedding 50 ini

Iteration 44: pruning conformers...
current: 10 new: 50
Removed 52 duplicate conformers with cregen

Iteration 45: embedding 50 initial guesses...
Iteration 45: optimizing initial guesses...
Iteration 45: pruning conformers...
current: 8 new: 50
Removed 50 duplicate conformers with cregen

Iteration 46: embedding 50 initial guesses...
Iteration 46: optimizing initial guesses...
Iteration 46: pruning conformers...
current: 8 new: 50
Removed 49 duplicate conformers with cregen

Iteration 47: embedding 50 initial guesses...
Iteration 47: optimizing initial guesses...
Iteration 47: pruning conformers...
current: 9 new: 50
Removed 49 duplicate conformers with cregen

Iteration 48: embedding 50 initial guesses...
Iteration 48: optimizing initial guesses...
Iteration 48: pruning conformers...
current: 10 new: 50
Removed 49 duplicate conformers with cregen

Iteration 49: embedding 50 initial guesses...
Iteration 49: optimizing initial guesses...
Iteration 49: pruning conformers...
current: 11 

KeyboardInterrupt: 

0

In [70]:
# Gather the unique conformers into one molecule for visualization.
final_mol = scg.mol.Copy()
# Plain loop: AddConformer is called for its side effect, not to build a list.
for conf_data in scg.unique_mol_data:
    final_mol._mol.AddConformer(conf_data["conf"].ToConformer(), assignId=True)

In [71]:
# Browse the collected conformers interactively, with atom indices shown.
interactive_conformer_viewer(
    final_mol,
    viewer_size=(800, 800),
    atom_index=True,
)

interactive(children=(IntSlider(value=0, description='confId', max=10), Output()), _dom_classes=('widget-inter…

<function rdmc.view.interactive_conformer_viewer.<locals>.<lambda>(confId)>

In [46]:
import os
import shutil

# Path to the CREST executable used by run_cre_check. Prefer the active conda
# environment's copy; otherwise fall back to a PATH lookup. The old expression
# raised TypeError whenever CONDA_PREFIX was unset.
_conda_prefix = os.environ.get("CONDA_PREFIX")
_conda_crest = os.path.join(_conda_prefix, "bin", "crest") if _conda_prefix else None
if _conda_crest is not None and os.path.isfile(_conda_crest):
    CREST_BINARY = _conda_crest
else:
    CREST_BINARY = shutil.which("crest") or _conda_crest or "crest"

def make_xyz_text(rd_mol,
                  comment,
                  positions=None):
    """Format a molecule as a single XYZ frame.

    Parameters
    ----------
    rd_mol : molecule
        provides atoms via ``GetAtoms()`` and, when ``positions`` is None, the
        coordinates of its first conformer
    comment : str
        text for the XYZ comment (second) line
    positions : array-like, optional
        per-atom ``(x, y, z)`` coordinates overriding the first conformer

    Returns
    -------
    str
        atom count, comment, then one ``symbol x y z`` line per atom
        (no trailing newline)
    """
    atoms = list(rd_mol.GetAtoms())
    if positions is None:
        positions = rd_mol.GetConformers()[0].GetPositions()

    lines = [str(len(atoms)), comment]
    # Each atom line keeps its trailing space so output is byte-identical to before
    lines.extend(
        "%s %.8f %.8f %.8f " % (atom.GetSymbol(), xyz[0], xyz[1], xyz[2])
        for atom, xyz in zip(atoms, positions)
    )
    return "\n".join(lines)


def write_confs_xyz(confs, path):
    """Write conformer entries to ``path`` as a multi-frame XYZ ensemble.

    Parameters
    ----------
    confs : list of dict
        entries with ``"conf"`` (providing ``GetOwningMol()``) and ``"energy"``;
        the entry's ``"positions"`` are written when present
    path : str
        output file; frames are separated by a single newline

    The comment line of frame ``i`` (0-based) is ``"<energy> !CONF<i+1>"``.

    NOTE(review): the optimizers in this notebook store energies in kcal/mol,
    while CREST ensemble files conventionally carry Hartree -- confirm which
    unit CREGEN expects here.
    """
    frames = []
    for i, conf in enumerate(confs):
        rd_mol = conf["conf"].GetOwningMol()
        comment = "%.8f !CONF%d" % (conf["energy"], i + 1)
        # Use this entry's own coordinates: when several entries share an owning
        # mol with many conformers, that mol's conformer 0 would otherwise be
        # written for every entry.
        frames.append(make_xyz_text(rd_mol=rd_mol,
                                    comment=comment,
                                    positions=conf.get("positions")))

    with open(path, 'w') as f:
        f.write("\n".join(frames))


def read_unique(job_dir):
    """Parse ``<job_dir>/enso.tags`` into 0-based conformer indices.

    Blank lines are skipped. For every other line, the text after the last
    ``!CONF`` in its final token (or the whole token if there is none) must be
    a 1-based integer index. Returns None as soon as one is not.
    """
    tags_path = os.path.join(job_dir, "enso.tags")
    with open(tags_path, 'r') as handle:
        contents = handle.readlines()

    indices = []
    for raw in contents:
        tokens = raw.strip().split()
        if tokens:
            label = tokens[-1].split("!CONF")[-1]
            if not label.isdigit():
                return None  # means something went wrong
            indices.append(int(label) - 1)

    return indices

        
def run_cre_check(confs):
    """Deduplicate conformer entries with CREST's CREGEN and return the survivors.

    Parameters
    ----------
    confs : list of dict
        entries accepted by ``write_confs_xyz``

    Returns
    -------
    list of dict
        the entries CREGEN keeps, in the order listed in ``enso.tags``
        (empty for empty input)

    Raises
    ------
    ValueError
        if CREST exits non-zero or its ``enso.tags`` output cannot be parsed; the
        scratch directory is kept so its log can be inspected. On success it is
        removed (previously one directory leaked per call).
    """
    if not confs:
        return []

    temp_dir = tempfile.mkdtemp()

    logfile = os.path.join(temp_dir, "xtb.log")
    confs_path = os.path.join(temp_dir, "confs.xyz")
    conf_0_path = os.path.join(temp_dir, "conf_0.xyz")

    write_confs_xyz(confs, path=confs_path)
    write_confs_xyz(confs[:1], path=conf_0_path)

    # The argument list goes straight to the executable (no shell), so a ">"
    # token would reach CREST literally instead of redirecting output. All
    # output is captured in the log file instead.
    with open(logfile, "w") as f:
        cregen_run = subprocess.run(
            [
                CREST_BINARY,
                conf_0_path,
                "--cregen",
                confs_path,
                "--ethr", "0.05", "--rthr", "0.125", "--bthr", "0.01", "--ewin", "10000", "--enso",
            ],
            stdout=f,
            stderr=subprocess.STDOUT,
            cwd=temp_dir,
            env=XTB_ENV,
        )

    if cregen_run.returncode != 0:
        raise ValueError(f"CREST CREGEN run failed. See {logfile} for details.")

    unique_ids = read_unique(temp_dir)
    if unique_ids is None:
        raise ValueError(f"Could not parse CREGEN output in {temp_dir}. See {logfile} for details.")

    updated_confs = [confs[i] for i in unique_ids]
    rmtree(temp_dir, ignore_errors=True)
    return updated_confs

In [8]:
import json
import os
from shutil import rmtree
import subprocess
import tempfile

import numpy as np
from openbabel import pybel

from rdmc.external.xtb.utils import (
    ATOM_ENERGIES_XTB,
    ATOMNUM_TO_ELEM,
    AU_TO_DEBYE,
    EV_TO_HARTREE,
    UTILS_PATH,
    XTB_BINARY,
)

# xTB input file from rdmc's xTB utilities (passed to xtb via --input)
XTB_INPUT_FILE = os.path.join(UTILS_PATH, "xtb.inp")

# Environment for the xTB/CREST subprocesses. subprocess.run(env=...) REPLACES
# the whole environment, so start from the parent's (keeping PATH, HOME and any
# xTB-specific variables) and then pin OpenMP/MKL to one thread with a 1G stack.
XTB_ENV = {
    **os.environ,
    "OMP_STACKSIZE": "1G",
    "OMP_NUM_THREADS": "1",
    "OMP_MAX_ACTIVE_LEVELS": "1",
    "MKL_NUM_THREADS": "1",
}


def read_xtb_json(json_file, mol):
    """Reads JSON output file from xTB into a property dictionary.
    Parameters
    ----------
    json_file : str
        path to the JSON file written by xTB
    mol : molecule object
        must expose ``GetAtoms()`` whose atoms expose ``GetAtomicNum()``
        (RDKit-style API, not pybel); used to subtract isolated-atom energies
    Returns
    -------
    dict
        ``E_form``, ``E_homo``, ``E_lumo``, ``E_gap`` (Hartree),
        ``dipole`` (Debye) and ``charges`` (xTB partial charges)
    """
    with open(json_file, "r") as handle:
        payload = json.load(handle)

    homo_ev, lumo_ev = get_homo_and_lumo_energies(payload)
    elements = [ATOMNUM_TO_ELEM[atom.GetAtomicNum()] for atom in mol.GetAtoms()]
    reference_energy = sum(ATOM_ENERGIES_XTB[element] for element in elements)

    return {
        "E_form": payload["total energy"] - reference_energy,  # already in Hartree
        "E_homo": homo_ev * EV_TO_HARTREE,
        "E_lumo": lumo_ev * EV_TO_HARTREE,
        "E_gap": payload["HOMO-LUMO gap/eV"] * EV_TO_HARTREE,
        "dipole": np.linalg.norm(payload["dipole"]) * AU_TO_DEBYE,
        "charges": payload["partial charges"],
    }


def get_homo_and_lumo_energies(data):
    """Extract the HOMO and LUMO energies from an xTB JSON dictionary.
    Parameters
    ----------
    data : dict
        dictionary from xTB JSON output
    Returns
    -------
    tuple(float)
        (E_homo, E_lumo) in eV
    Raises
    ------
    ValueError
        in case of unpaired electrons (not supported)
    """
    if data["number of unpaired electrons"] != 0:
        raise ValueError("Unpaired electrons are not supported.")

    # Orbitals with occupation <= 1e-6 count as empty (xTB occasionally reports tiny values)
    occupations = np.asarray(data["fractional occupation"])
    n_occupied = int(np.count_nonzero(occupations > 1e-6))

    orbital_energies = data["orbital energies/eV"]
    homo_index = n_occupied - 1  # zero-indexing; the LUMO follows directly
    return orbital_energies[homo_index], orbital_energies[n_occupied]


def get_wbo(wbo_file):
    """Reads the WBO output file from xTB into a dictionary of Wiberg bond orders.
    Parameters
    ----------
    wbo_file : str
        path to xTB wbo output file; each non-blank line holds two 1-based
        atom indices followed by the bond order
    Returns
    -------
    dict
        maps ``"i-j"`` (0-based atom indices, ``i <= j``) to the bond order (float)
    """
    wbo_dict = {}
    with open(wbo_file, "r") as f:
        for line in f:
            # Whitespace split instead of fixed-width slicing: same result for
            # well-formed lines, and blank lines no longer crash int('').
            fields = line.split()
            if not fields:
                continue
            a1, a2 = int(fields[0]) - 1, int(fields[1]) - 1
            wbo_dict[f"{min(a1, a2)}-{max(a1, a2)}"] = float(fields[2])
    return wbo_dict


def run_xtb_calc(mol, opt=False, return_optmol=False, method="gfn2"):
    """Runs an xTB single-point calculation with optional geometry optimization.
    Parameters
    ----------
    mol : RDKitMol
        needs ``ToSDFFile`` and ``GetFormalCharge`` (and ``GetAtoms`` for
        non-``gff`` methods); assumes hydrogens are present
    opt : bool, optional
        Whether to optimize the geometry, by default False
    return_optmol : bool, optional
        Whether to return the optimized molecule; implies ``opt=True``, by default False
    method : str, optional
        xTB method flag without leading dashes (e.g. ``"gfn2"``, ``"gff"``), by default "gfn2"
    Returns
    -------
    dict or None, or tuple
        Properties from ``read_xtb_json`` plus ``"wbo"``; ``None`` for ``"gff"``
        (its JSON is not parsed). With ``return_optmol`` a ``(props, opt_mol)`` tuple.
    Raises
    ------
    ValueError
        If xTB returns a non-zero code; the scratch directory is kept for the log.
    """
    import logging

    if return_optmol and not opt:
        logging.getLogger(__name__).info(
            "Can't have `return_optmol` set to True with `opt` set to False. Setting the latter to True now."
        )
        opt = True

    method_flag = "--" + method
    temp_dir = tempfile.mkdtemp()
    logfile = os.path.join(temp_dir, "xtb.log")
    xtb_out = os.path.join(temp_dir, "xtbout.json")
    xtb_wbo = os.path.join(temp_dir, "wbo")

    sdf_path = os.path.join(temp_dir, "mol.sdf")
    mol.ToSDFFile(sdf_path)

    # Only add "--opt" when requested: an empty-string argument would otherwise
    # be passed to xtb verbatim.
    command = [XTB_BINARY, sdf_path]
    if opt:
        command.append("--opt")
    command += [
        method_flag,
        "--input",
        XTB_INPUT_FILE,
        "--chrg",
        str(mol.GetFormalCharge()),
        "--wbo",
        "--json",
        "true",
    ]

    with open(logfile, "w") as f:
        xtb_run = subprocess.run(
            command,
            stdout=f,
            stderr=subprocess.STDOUT,
            cwd=temp_dir,
            env=XTB_ENV,
        )
    if xtb_run.returncode != 0:
        raise ValueError(f"xTB calculation failed. See {logfile} for details.")

    try:
        opt_mol = None
        # The optimized geometry only exists when an optimization was run
        if return_optmol:
            opt_mol = RDKitMol.FromFile(os.path.join(temp_dir, "xtbopt.sdf"))[0]
        if method_flag == "--gff":
            props = None
        else:
            props = read_xtb_json(xtb_out, mol)
            props.update({"wbo": get_wbo(xtb_wbo)})
    finally:
        # Clean up even if parsing the outputs fails
        rmtree(temp_dir)

    return (props, opt_mol) if return_optmol else props