In [1]:
import conformer_rl
from rdkit import Chem
from rdkit.Chem import AllChem
from conformer_rl.environments.environments import GibbsScoreLogPruningEnv
from conformer_rl.molecule_generation.molecules import mol_from_molFile
from conformer_rl.config.mol_config import MolConfig
import time

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
# Cap on MMFF optimisation iterations. MMFFMixin._step looks this module-level
# name up on every call (maxIters=ITERS), so rebinding ITERS in a later cell
# also changes the behaviour of environments that were already constructed.
ITERS = 200

class MMFFMixin:
    """Environment mixin: set torsions from a discrete action, then MMFF-relax.

    The host class is expected to provide ``nonring``, ``conf``, ``mol`` and
    ``episode_info``; ``_step`` reads all four.
    """

    def _step(self, action) -> None:
        """Rotate every torsion in ``self.nonring`` and optimise the result.

        Args:
            action: Indexable of numbers with one entry per element of
                ``self.nonring``. Entry ``k`` selects a dihedral of
                ``-180 + 60 * action[k]`` degrees for the ``k``-th torsion.

        Side effects:
            Calls ``SetDihedralDeg`` on ``self.conf`` once per torsion, runs
            ``MMFFOptimizeMolecule`` on ``self.mol`` with ``maxIters`` taken
            from the module-level ``ITERS`` at call time, then appends
            ``self.conf`` to ``self.episode_info['mol']``.
        """
        for position, torsion in enumerate(self.nonring):
            degrees = -180 + 60 * action[position]
            # float() keeps the angle a plain Python float even when the
            # action entries are numpy integers.
            Chem.rdMolTransforms.SetDihedralDeg(self.conf, *torsion, float(degrees))

        # NOTE(review): targets the conformer whose id equals count - 1, which
        # presumably relies on conformer ids being contiguous from 0 — confirm.
        newest_conf_id = self.mol.GetNumConformers() - 1
        Chem.AllChem.MMFFOptimizeMolecule(
            self.mol,
            maxIters=ITERS,
            nonBondedThresh=10.,
            confId=newest_conf_id,
        )

        history = self.episode_info['mol']
        history.AddConformer(self.conf, assignId=True)

class CustomEnv(MMFFMixin, GibbsScoreLogPruningEnv):
    """GibbsScoreLogPruningEnv whose ``_step`` is supplied by MMFFMixin.

    MMFFMixin is listed first so that its ``_step`` wins in the MRO.
    """

In [3]:
# Load the molecule from '8monomers.mol', build the environment and report how
# long the setup took. perf_counter() is the clock intended for measuring
# elapsed intervals.
start = time.perf_counter()
mol = AllChem.MolFromMolFile('8monomers.mol')
if mol is None:
    # RDKit returns None when the file cannot be parsed; fail here with a
    # clear message instead of a confusing error inside AddHs.
    raise ValueError("could not read a molecule from '8monomers.mol'")
# addCoords=True gives the added hydrogens 3D positions derived from the
# conformer read from the file; by default RDKit adds them without coordinates.
mol = AllChem.AddHs(mol, addCoords=True)
AllChem.MMFFSanitizeMolecule(mol)
mol_config = MolConfig()
mol_config.mol = mol
# Constants recorded from an earlier run (see the notes in the drawing cell).
mol_config.E0 = 426.1538271531945
mol_config.Z0 = 1.116382947893354
# NOTE(review): the second positional argument is presumably the per-episode
# step limit — confirm against the environment's constructor.
env = CustomEnv(mol_config, 1000)
print(time.perf_counter() - start)

0.10844588279724121


In [4]:
# Values noted from earlier runs of the (now disabled) print statements:
#   print(mol_config.E0, mol_config.Z0) -> E0: 426.1538271531945, Z0: 1.116382947893354
#   print(len(env.nonring))             -> 57
# NOTE(review): AllChem is already imported in the first cell, so the
# re-import below is redundant; it is left untouched here.
from rdkit.Chem import AllChem
from conformer_rl.analysis import drawConformer
# Trailing expression: the returned viewer object becomes the cell's output.
drawConformer(mol)


<py3Dmol.view at 0x2ab23a659490>

In [7]:
import numpy as np

# Benchmark: four random steps with the MMFF optimiser capped at 100 iterations.
# MMFFMixin._step reads the module-level ITERS at call time, so rebinding it
# here changes the cap used by the already-constructed `env`.
ITERS = 100
# Fixed seed so the sampled actions are reproducible from run to run.
rng = np.random.default_rng(0)
env.reset()
# One action entry per torsion in env.nonring instead of a hard-coded count,
# so the cell keeps working if a different molecule is loaded.
n_torsions = len(env.nonring)
for i in range(4):
    start = time.perf_counter()
    # integers(0, 6) draws from 0..5; _step maps each value to a 60-degree bin.
    obs, reward, done, info = env.step(rng.integers(0, 6, size=n_torsions))
    print("iteration", i, time.perf_counter() - start, "seconds", "reward:", reward)
    # Show only the conformer produced by this step.
    mol.RemoveAllConformers()
    mol.AddConformer(env.conf)
    view = drawConformer(mol)
    view.show()
    if done:
        env.reset()


iteration 0 0.4878566265106201 seconds reward: -495.0796578593804


iteration 1 0.4832799434661865 seconds reward: 0.0


iteration 2 0.5434741973876953 seconds reward: 0.0


iteration 3 0.5193750858306885 seconds reward: 0.0


In [8]:
# Benchmark: four random steps with the MMFF optimiser capped at 200 iterations.
# MMFFMixin._step reads the module-level ITERS at call time, so rebinding it
# here changes the cap used by the already-constructed `env`.
ITERS = 200
# Fixed seed so the sampled actions are reproducible from run to run.
rng = np.random.default_rng(0)
env.reset()
# One action entry per torsion in env.nonring instead of a hard-coded count,
# so the cell keeps working if a different molecule is loaded.
n_torsions = len(env.nonring)
for i in range(4):
    start = time.perf_counter()
    # integers(0, 6) draws from 0..5; _step maps each value to a 60-degree bin.
    obs, reward, done, info = env.step(rng.integers(0, 6, size=n_torsions))
    print("iteration", i, time.perf_counter() - start, "seconds", "reward:", reward)
    # Show only the conformer produced by this step.
    mol.RemoveAllConformers()
    mol.AddConformer(env.conf)
    view = drawConformer(mol)
    view.show()
    if done:
        env.reset()

iteration 0 0.8170318603515625 seconds reward: -342.67832758457706


iteration 1 0.9213120937347412 seconds reward: 64.79741600534095


iteration 2 0.9858808517456055 seconds reward: 55.228095885219545


iteration 3 0.8858189582824707 seconds reward: 0.0


In [9]:
# Benchmark: four random steps with the MMFF optimiser capped at 400 iterations.
# MMFFMixin._step reads the module-level ITERS at call time, so rebinding it
# here changes the cap used by the already-constructed `env`.
ITERS = 400
# Fixed seed so the sampled actions are reproducible from run to run.
rng = np.random.default_rng(0)
env.reset()
# One action entry per torsion in env.nonring instead of a hard-coded count,
# so the cell keeps working if a different molecule is loaded.
n_torsions = len(env.nonring)
for i in range(4):
    start = time.perf_counter()
    # integers(0, 6) draws from 0..5; _step maps each value to a 60-degree bin.
    obs, reward, done, info = env.step(rng.integers(0, 6, size=n_torsions))
    print("iteration", i, time.perf_counter() - start, "seconds", "reward:", reward)
    # Show only the conformer produced by this step.
    mol.RemoveAllConformers()
    mol.AddConformer(env.conf)
    view = drawConformer(mol)
    view.show()
    if done:
        env.reset()

iteration 0 2.0503287315368652 seconds reward: -302.7649361987071


iteration 1 1.7039501667022705 seconds reward: 120.73838971302325


iteration 2 1.8008055686950684 seconds reward: 0.0


iteration 3 1.8306055068969727 seconds reward: 9.791259572011768e-07


In [10]:
# Benchmark: four random steps with the MMFF optimiser capped at 1000 iterations.
# MMFFMixin._step reads the module-level ITERS at call time, so rebinding it
# here changes the cap used by the already-constructed `env`.
ITERS = 1000
# Fixed seed so the sampled actions are reproducible from run to run.
rng = np.random.default_rng(0)
env.reset()
# One action entry per torsion in env.nonring instead of a hard-coded count,
# so the cell keeps working if a different molecule is loaded.
n_torsions = len(env.nonring)
for i in range(4):
    start = time.perf_counter()
    # integers(0, 6) draws from 0..5; _step maps each value to a 60-degree bin.
    obs, reward, done, info = env.step(rng.integers(0, 6, size=n_torsions))
    print("iteration", i, time.perf_counter() - start, "seconds", "reward:", reward)
    # Show only the conformer produced by this step.
    mol.RemoveAllConformers()
    mol.AddConformer(env.conf)
    view = drawConformer(mol)
    view.show()
    if done:
        env.reset()

iteration 0 4.335876703262329 seconds reward: -164.19041758350392


iteration 1 4.276808261871338 seconds reward: 0.0


iteration 2 4.08186936378479 seconds reward: 0.0


iteration 3 3.8592026233673096 seconds reward: 0.0


In [13]:
# Benchmark: four random steps with the MMFF optimiser capped at 50 iterations.
# MMFFMixin._step reads the module-level ITERS at call time, so rebinding it
# here changes the cap used by the already-constructed `env`.
ITERS = 50
# Fixed seed so the sampled actions are reproducible from run to run.
rng = np.random.default_rng(0)
env.reset()
# One action entry per torsion in env.nonring instead of a hard-coded count,
# so the cell keeps working if a different molecule is loaded.
n_torsions = len(env.nonring)
for i in range(4):
    start = time.perf_counter()
    # integers(0, 6) draws from 0..5; _step maps each value to a 60-degree bin.
    obs, reward, done, info = env.step(rng.integers(0, 6, size=n_torsions))
    print("iteration", i, time.perf_counter() - start, "seconds", "reward:", reward)
    # Show only the conformer produced by this step.
    mol.RemoveAllConformers()
    mol.AddConformer(env.conf)
    view = drawConformer(mol)
    view.show()
    if done:
        env.reset()

iteration 0 0.22055721282958984 seconds reward: 2.220446049250313e-16


iteration 1 0.3179352283477783 seconds reward: 0.0


iteration 2 0.31601858139038086 seconds reward: 0.0


iteration 3 0.33186960220336914 seconds reward: 0.0
