In [1]:
%load_ext autoreload
# Re-import edited project modules (e.g. utils.utils) automatically before each cell runs.
%autoreload 2

In [None]:
import os
import sys
import copy
import time
from pathlib import Path
from datetime import datetime
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm

# Make the project root (parent of the current working directory, i.e. the
# notebook's folder when launched from Jupyter) importable so that
# `utils.utils` below resolves to the project package. Insert it at the front
# so an installed third-party package that is also named `utils` cannot
# shadow it.
module_path = os.path.abspath("..")
if module_path not in sys.path:
    sys.path.insert(0, module_path)

from utils.utils import set_seed

In [3]:
import dgl
import flowmol

In [4]:
def sampling(config: OmegaConf, model: flowmol.FlowMol, device: torch.device):
    """Sample ``config.num_samples`` molecules of random size from ``model``.

    ``config.num_integration_steps + 1`` is forwarded as ``n_timesteps``
    (presumably the time grid includes both endpoints - confirm against
    FlowMol's sampler).

    Returns whatever ``model.sample_random_sizes`` returns.
    """
    timesteps = config.num_integration_steps + 1
    return model.sample_random_sizes(
        n_molecules=config.num_samples,
        n_timesteps=timesteps,
        device=device,
    )

In [5]:
def setup_gen_model(flow_model: str, device: torch.device):
    """Load the pretrained FlowMol model named ``flow_model`` onto ``device``.

    Returns the object produced by ``flowmol.load_pretrained`` after calling
    ``.to(device)`` on it.
    """
    model = flowmol.load_pretrained(flow_model)
    # Called for its side effect; the loaded object itself is returned.
    model.to(device)
    return model

In [6]:
import py3Dmol
from rdkit import Chem
from rdkit.Chem import AllChem

def print_molecule(mol, width: int = 400, height: int = 400, random_seed: int = 42):
    """Embed a copy of ``mol`` in 3D with explicit hydrogens and display it with py3Dmol.

    The input molecule is left untouched. A copy is sanitized, given explicit
    hydrogens, embedded with a fixed random seed and relaxed with the UFF
    force field. Stereochemistry is then perceived from its 3D coordinates.
    The result is rendered as sticks plus small spheres.

    Args:
        mol: RDKit ``Mol`` to visualize.
        width: Viewer width in pixels.
        height: Viewer height in pixels.
        random_seed: Seed passed to ``AllChem.EmbedMolecule`` so the
            conformer is reproducible.

    Returns:
        The 3D-embedded molecule (with hydrogens) that was displayed.

    Raises:
        ValueError: If RDKit cannot embed the molecule in 3D.
    """
    # Work on a copy: SanitizeMol mutates its argument in place.
    mol_3d = Chem.Mol(mol)
    Chem.SanitizeMol(mol_3d)
    mol_3d = Chem.AddHs(mol_3d)

    # EmbedMolecule returns -1 on failure; UFF optimization needs a conformer,
    # so fail early with a clear message instead.
    if AllChem.EmbedMolecule(mol_3d, randomSeed=random_seed) == -1:
        raise ValueError("RDKit could not generate a 3D conformer for this molecule.")
    AllChem.UFFOptimizeMolecule(mol_3d)

    # Assign atom/bond stereochemistry from the embedded 3D geometry.
    Chem.AssignStereochemistryFrom3D(mol_3d)

    view = py3Dmol.view(width=width, height=height)
    view.addModel(Chem.MolToMolBlock(mol_3d), "mol")
    view.setStyle({"stick": {}, "sphere": {"scale": 0.3}})
    view.zoomTo()
    view.show()
    return mol_3d

### Generate molecules with FlowMol

In [7]:
config = OmegaConf.create(
    dict(
        seed=42,                    # passed to set_seed
        flow_model="qm9_ctmc",      # pretrained FlowMol checkpoint name
        num_samples=5,              # number of molecules to generate
        num_integration_steps=100,  # integration steps used by `sampling`
    )
)

In [None]:
# Seed via the project's set_seed helper, then use the first CUDA GPU if present.
set_seed(config.seed)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [9]:
# Load the pretrained generator named in the config onto the selected device.
gen_model = setup_gen_model(config.flow_model, device)

In [11]:
generated_molecules = sampling(config=config, model=gen_model, device=device)  # draw config.num_samples molecules

  assert input.numel() == input.storage().size(), (


In [12]:
# Pull the RDKit molecule and SMILES string out of each sample in a single pass.
# NOTE(review): rdkit_mol/smiles may be None for invalid samples - confirm with FlowMol.
mol_smiles_pairs = [(sample.rdkit_mol, sample.smiles) for sample in generated_molecules]
gen_rdkit_mols = [rdkit_mol for rdkit_mol, _ in mol_smiles_pairs]
gen_smiles = [smiles for _, smiles in mol_smiles_pairs]

gen_smiles

['[H]C#CC1(C([H])=O)[NH+](C([H])([H])C([H])([H])[H])C1([H])[H]',
 '[H]C(=O)C1([H])C([H])([H])C2(C([H])([H])C([H])([H])[H])C([H])([H])C12[H]',
 '[H]C1=C(C([H])([H])C([H])([H])[H])C2(C([H])([H])[H])C([H])([H])C2([H])C1([H])[H]',
 '[H]OC([H])([H])C1([H])OC(=O)C2([H])OC21[H]',
 '[H]C1([H])C(=O)C2([H])OC2(C([H])([H])[H])C1=O']

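#### Synthetic accessibility (SA) score

SA scores range from 1 (easy to synthesize) to 10 (very hard to synthesize).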

In [13]:
from molscore.scoring_functions.SA_Score import sascorer
from rdkit import Chem

OpenEye functions: currently unavailable due to the following: No module named 'openeye'
PoseCheck metrics: currently unavailable due to the following: No module named 'prolif'


In [14]:
def get_sa_score(mol):
    """Compute the synthetic accessibility (SA) score of an RDKit molecule.

    Side effect: ring information is perceived and ``mol`` is sanitized in
    place before scoring, so the caller's molecule is modified.

    Args:
        mol: RDKit ``Mol`` to score.

    Returns:
        The value returned by ``sascorer.calculateScore``.
    """
    # Perceive rings (smallest set of smallest rings), then sanitize, before scoring.
    Chem.GetSSSR(mol)
    Chem.SanitizeMol(mol)
    return sascorer.calculateScore(mol)

In [15]:
# Reuse get_sa_score so the scoring recipe lives in one place.
# Like the helper, this sanitizes each molecule in gen_rdkit_mols in place.
sa_scores = [get_sa_score(rdkit_mol) for rdkit_mol in gen_rdkit_mols]

sa_scores

[7.511121331446639,
 7.7398463812059655,
 7.7752512797078035,
 7.653818612869355,
 7.0971109891641975]

In [16]:
from posebusters import PoseBusters

In [17]:
# Molecule-only PoseBusters checks: just the generated molecules are passed.
# The two None positional args leave the reference/conditioning inputs empty
# (assumed to be PoseBusters' mol_true / mol_cond - confirm).
buster = PoseBusters(config="mol")
df = buster.bust(gen_rdkit_mols, None, None, full_report=False)
print(tuple(df.shape))
df

(5, 10)


Unnamed: 0_level_0,Unnamed: 1_level_0,mol_pred_loaded,sanitization,inchi_convertible,all_atoms_connected,bond_lengths,bond_angles,internal_steric_clash,aromatic_ring_flatness,double_bond_flatness,internal_energy
file,molecule,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
<rdkit.Chem.rdchem.Mol object at 0x333502d60>,mol_at_pos_0,True,True,True,True,True,True,True,True,True,True
<rdkit.Chem.rdchem.Mol object at 0x3334ef660>,mol_at_pos_0,True,True,True,True,True,True,True,True,True,True
<rdkit.Chem.rdchem.Mol object at 0x3334ef580>,mol_at_pos_0,True,True,True,True,True,True,True,True,True,True
<rdkit.Chem.rdchem.Mol object at 0x3334ef3c0>,mol_at_pos_0,True,True,True,True,True,True,True,True,True,True
<rdkit.Chem.rdchem.Mol object at 0x3334ef350>,mol_at_pos_0,True,True,True,True,True,True,True,True,True,True


In [18]:
df.keys()  # a DataFrame's keys are its column labels

Index(['mol_pred_loaded', 'sanitization', 'inchi_convertible',
       'all_atoms_connected', 'bond_lengths', 'bond_angles',
       'internal_steric_clash', 'aromatic_ring_flatness',
       'double_bond_flatness', 'internal_energy'],
      dtype='object')

In [19]:
df.iloc[0, :]  # all check results for the first molecule

mol_pred_loaded           True
sanitization              True
inchi_convertible         True
all_atoms_connected       True
bond_lengths              True
bond_angles               True
internal_steric_clash     True
aromatic_ring_flatness    True
double_bond_flatness      True
internal_energy           True
Name: (<rdkit.Chem.rdchem.Mol object at 0x333502d60>, mol_at_pos_0), dtype: bool

#### Visualize the Molecules

In [20]:
import py3Dmol

In [None]:
from rdkit import Chem
import py3Dmol

# Show at most MAX_MOLS_TO_SHOW molecules, GRID_COLS per row; each grid cell is
# CELL_WIDTH x CELL_HEIGHT pixels (a full 2x3 grid is 600x600).
MAX_MOLS_TO_SHOW = 6
GRID_COLS = 3
CELL_WIDTH = 200
CELL_HEIGHT = 300

# Serialize each RDKit Mol as an MDL Mol block (not PDB); py3Dmol parses these
# with the "mol" format passed to addModel below.
mol_blocks = [Chem.MolToMolBlock(m) for m in gen_rdkit_mols[:MAX_MOLS_TO_SHOW]]

if not mol_blocks:
    print("No molecules to visualize.")
else:
    # Size the grid from the number of molecules instead of hard-coding it.
    n_cols = min(GRID_COLS, len(mol_blocks))
    n_rows = -(-len(mol_blocks) // n_cols)  # ceiling division
    viewer = py3Dmol.view(
        width=CELL_WIDTH * n_cols,
        height=CELL_HEIGHT * n_rows,
        viewergrid=(n_rows, n_cols),
    )
    for idx, block in enumerate(mol_blocks):
        row, col = divmod(idx, n_cols)
        viewer.addModel(block, "mol", viewer=(row, col))

    viewer.setStyle({"stick": {}, "sphere": {"scale": 0.3}})
    viewer.zoomTo()
    viewer.show()
