In [1]:
import pickle
from ase import Atoms
from rdkit import Chem
from rdkit.Chem import AllChem
import py3Dmol
from xtb.ase.calculator import XTB
from ase.optimize import BFGS
import time
import numpy as np
import pandas as pd
from copy import deepcopy
from rdkit.Geometry import Point3D
import torch
import random
import argparse
from model.bgrl import BGRL
from model.edieggc import EDiEGGC
from data.data import pos_z_to_graph, prepare_line_dgl

In [2]:

parser = argparse.ArgumentParser(description='dipole moment prediction training')
"""Experiment setting."""
parser.add_argument('--lacl', type=bool, default=True, help="True for training LACL, False for training")
parser.add_argument('--finetune', type=bool, default=False, help="")
parser.add_argument('--freeze', type=bool, default=False, help="")
parser.add_argument('--update_moving_average', type=bool, default=True, help="For BGRL")
parser.add_argument('--loss', type=str, default='contrastive+prediction', help="")
parser.add_argument('--num-workers', type=int, default=6, help="Number of workers for dataloader")
parser.add_argument('--dataset', type=str, default='QMugs', help="QM9 or QMugs")
parser.add_argument('--set', type=str, default='src', help="'src' for source domain and 'tgt' for target domain")
parser.add_argument('--target', type=str, default='GFN2:DIPOLE', help="homo, lumo, gap, mu, ...")
parser.add_argument('--geometry', type=str, default='MMFF', help="")
parser.add_argument('--epochs', type=int, default=300, help="")
parser.add_argument('--num-train', type=int, default=65000, help="110000/65000")
parser.add_argument('--num-valid', type=int, default=1500, help="10000/1500")
parser.add_argument('--num-test', type=int, default=1706, help="10829/1706")
parser.add_argument('--batch-size', type=int, default=32, help="")
parser.add_argument('--learning-rate', type=float, default=1e-3, help="")
parser.add_argument('--weight-decay', type=float, default=0, help="")
parser.add_argument('--max-norm', type=float, default=1000.0, help="")
parser.add_argument('--scheduler', type=str, default='plateau', help="")
parser.add_argument('--cutoff', type=float, default=4.0, help="")
parser.add_argument('--device', type=str, default='cuda:1', help="cuda device")
'''Model setting'''
parser.add_argument('--embedding-type', type=str, default='cgcnn', help="")
parser.add_argument('--alignn-layers', type=int, default=4, help="")
parser.add_argument('--gcn-layers', type=int, default=4, help="")
parser.add_argument('--atom-input-features', type=int, default=92, help="")
parser.add_argument('--edge-input-features', type=int, default=80, help="")
parser.add_argument('--triplet-input-features', type=int, default=40, help="")
parser.add_argument('--embedding-features', type=int, default=64, help="")
parser.add_argument('--hidden-features', type=int, default=256, help="")
parser.add_argument('--output-features', type=int, default=1, help="")
args = parser.parse_args([])

In [3]:
with open('data/QMugs/QMugs_20_energy.pkl', 'rb') as f:
    qmugs = pickle.load(f)
    
idx = list(range(len(qmugs)))
random.seed(123)
random.shuffle(idx)
qmugs = [qmugs[i] for i in idx[-args.num_test:]]

In [4]:
# select molecules with 20 heavy atoms
qmugs_20 = []
for i, m in enumerate(qmugs):
    mol = deepcopy(m['Molecule'][0])
    if mol.GetNumHeavyAtoms() == 20:
        qmugs_20.append((m['SMILES'][0], mol, float(m['GFN2:DIPOLE'][0].split('|')[-1])))

In [5]:
# Dihedral rotatable bonds are manually selected
js = [1,17,33, 47,59,66,70,76,87,90,91, 99]
abcd_s = [
    (5,6,8,9),
    (6,7,8,9),
    (5,4,3,2),
    (5,6,7,9),
    (1,3,4,18),
    (11,1,2,3),
    (9,10,11,12),
    (3,2,1,14),
    (5,6,8,9),
    (10,11,12,13),
    (5,6,7,8),
    (9,10,12,13)
]

In [55]:
# MMFF optimization to get model's input geometry and GFN2-xTB optimization to get labels for evaluation
mmff = []
gfn2 = []
imagemols = []
mus = []

for j, (a, b, c, d) in zip(js, abcd_s):
    mol = qmugs_20[j][1]

    # compact conformer
    closed_mol = deepcopy(mol)
    gfn2.append(deepcopy(closed_mol))
    conf = closed_mol.GetConformer()
    pos = conf.GetPositions()
    atoms = [atom.GetAtomicNum() for atom in closed_mol.GetAtoms()]
    charges = [atom.GetFormalCharge() for atom in mol.GetAtoms()]
    ase_mol = Atoms(numbers=atoms, positions=pos, charges= charges)
    ase_mol.calc = XTB(method='GFN2-xTB', n_cores=12, verbose=False)
    opt = BFGS(ase_mol)
    opt.run(fmax=0.01)
    newpos = np.array(ase_mol.get_positions())
    for i in range(closed_mol.GetNumAtoms()):
        x,y,z = newpos[i]
        conf.SetAtomPosition(i,Point3D(x,y,z))
    imagemol = deepcopy(closed_mol)
    # imagemol = Chem.RemoveHs(imagemol)
    Chem.Draw.MolToImageFile(imagemol, f'mol_images/{j}_compact.tif', size=(400, 400))
    imagemols.append(imagemol)
    mu = ase_mol.get_dipole_moment()
    mu = np.sqrt(mu[0]**2 + mu[1]**2 + mu[2]**2) * 4.803  # from e⋅Å to Debye
    mus.append(mu)
    with open('mol_images/result.txt', 'a') as f:
        f.write(f'{j}, {qmugs_20[j][0]}, compact, {mu}\n')
    AllChem.MMFFOptimizeMolecule(closed_mol, confId=0)
    mmff.append(deepcopy(closed_mol))

    # open conformer
    open_mol = deepcopy(mol)
    conf = open_mol.GetConformer()
    Chem.rdMolTransforms.SetDihedralDeg(conf, a, b, c, d, 180)
    imagemol = deepcopy(open_mol)
    # imagemol = Chem.RemoveHs(imagemol)
    Chem.Draw.MolToImageFile(imagemol, f'mol_images/{j}_open.tif', size=(400, 400))
    imagemols.append(imagemol)
    pos = conf.GetPositions()
    atoms = [atom.GetAtomicNum() for atom in open_mol.GetAtoms()]
    charges = [atom.GetFormalCharge() for atom in mol.GetAtoms()]
    ase_mol = Atoms(numbers=atoms, positions=pos, charges= charges)
    ase_mol.calc = XTB(method='GFN2-xTB', n_cores=12, verbose=False)
    opt = BFGS(ase_mol)
    opt.run(fmax=0.01)
    newpos = np.array(ase_mol.get_positions())
    for i in range(open_mol.GetNumAtoms()):
        x,y,z = newpos[i]
        conf.SetAtomPosition(i,Point3D(x,y,z))
    gfn2.append(deepcopy(open_mol))
    mu = ase_mol.get_dipole_moment()
    mu = np.sqrt(mu[0]**2 + mu[1]**2 + mu[2]**2) * 4.803  # from e⋅Å to Debye
    mus.append(mu)
    with open('mol_images/result.txt', 'a') as f:
        f.write(f'{j}, {qmugs_20[j][0]}, open, {mu}\n')
    AllChem.MMFFOptimizeMolecule(open_mol, confId=0)
    mmff.append(deepcopy(open_mol))


      Step     Time          Energy         fmax
BFGS:    0 06:24:40    -1579.845110        0.0204
BFGS:    1 06:24:40    -1579.845119        0.0176
BFGS:    2 06:24:40    -1579.845124        0.0057
      Step     Time          Energy         fmax
BFGS:    0 06:24:40    -1579.948314        0.9838
BFGS:    1 06:24:40    -1579.961953        0.5074
BFGS:    2 06:24:41    -1579.971002        0.3342
BFGS:    3 06:24:41    -1579.979040        0.2646
BFGS:    4 06:24:41    -1579.982844        0.2545
BFGS:    5 06:24:41    -1579.986038        0.2212
BFGS:    6 06:24:41    -1579.989535        0.1472
BFGS:    7 06:24:42    -1579.991507        0.1343
BFGS:    8 06:24:42    -1579.993082        0.1222
BFGS:    9 06:24:42    -1579.994684        0.1383
BFGS:   10 06:24:42    -1579.996848        0.1563
BFGS:   11 06:24:42    -1579.998838        0.1202
BFGS:   12 06:24:42    -1580.000308        0.0788
BFGS:   13 06:24:42    -1580.001341        0.0821
BFGS:   14 06:24:43    -1580.002243        0.0714
BF

In [56]:
lacl = BGRL(args)
lacl.load_state_dict(torch.load('./ckpts/qmugs_20_mmff_mu_lacl'))
lacl.eval()
source_finetune = EDiEGGC(args)
source_finetune.load_state_dict(torch.load('./ckpts/qmugs_20_mmff_mu_source_finetune'))
source_finetune.eval()
target_finetune = EDiEGGC(args)
target_finetune.load_state_dict(torch.load('./ckpts/qmugs_20_mmff_mu_target_finetune'))
target_finetune.eval()

EDiEGGC(
  (encoder): Encoder(
    (atom_embedding): MLPLayer(
      (layer): Sequential(
        (0): Linear(in_features=92, out_features=256, bias=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU()
      )
    )
    (edge_embedding): Sequential(
      (0): RBFExpansion()
      (1): MLPLayer(
        (layer): Sequential(
          (0): Linear(in_features=80, out_features=64, bias=True)
          (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): SiLU()
        )
      )
      (2): MLPLayer(
        (layer): Sequential(
          (0): Linear(in_features=64, out_features=256, bias=True)
          (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): SiLU()
        )
      )
    )
    (angle_embedding): Sequential(
      (0): RBFExpansion()
      (1): MLPLayer(
        (layer): Sequential(
          (0): Linear(in_features=40

In [57]:
conformation_test = []
for mol_g, mol_m, label in zip(gfn2, mmff, mus):
    item = {}
    mol = mol_g
    conf = mol.GetConformer()
    pos = conf.GetPositions()
    z = np.array([atom.GetAtomicNum() for atom in mol.GetAtoms()])
    g = pos_z_to_graph(pos, z, args.cutoff)
    lg = prepare_line_dgl(g)
    v, x, y, z = lacl.online_encoder(g, lg)
    item['lacl_pred_gfn2'] = lacl.online_decoder(g, lg, v, x, y, z).item()
    item['source_pred_gfn2'] = source_finetune(g, lg)[0].item()
    item['target_pred_gfn2'] = target_finetune(g, lg)[0].item()
    mol = mol_m
    conf = mol.GetConformer()
    pos = conf.GetPositions()
    z = np.array([atom.GetAtomicNum() for atom in mol.GetAtoms()])
    g = pos_z_to_graph(pos, z, args.cutoff)
    lg = prepare_line_dgl(g)
    v, x, y, z = lacl.online_encoder(g, lg)
    item['lacl_pred_mmff']  = lacl.online_decoder(g, lg, v, x, y, z).item()
    item['source_pred_mmff'] = source_finetune(g, lg)[0].item()
    item['target_pred_mmff'] = target_finetune(g, lg)[0].item()

    item['label'] = label
    conformation_test.append(item)

In [58]:
result_df = pd.DataFrame(conformation_test)
lacl_mae_mmff = (result_df['lacl_pred_mmff'] - result_df['label']).apply(lambda x: abs(x))
source_mae_mmff = (result_df['source_pred_mmff'] - result_df['label']).apply(lambda x: abs(x))
target_mae_mmff = (result_df['target_pred_mmff'] - result_df['label']).apply(lambda x: abs(x))
print('lacl_mae_compact', lacl_mae_mmff[[2*i for i in range(12)]].mean())
print('source_mae_compact', source_mae_mmff[[2*i for i in range(12)]].mean())
print('target_mae_compact', target_mae_mmff[[2*i for i in range(12)]].mean())
print('lacl_mae_open', lacl_mae_mmff[[2*i+1 for i in range(12)]].mean())
print('source_mae_open', source_mae_mmff[[2*i+1 for i in range(12)]].mean())
print('target_mae_open', target_mae_mmff[[2*i+1 for i in range(12)]].mean())
result_df

lacl_mae_compact 0.36606821664675765
source_mae_compact 0.5167488549663771
target_mae_compact 0.31999732836319283
lacl_mae_open 0.39939030500076517
source_mae_open 0.560853316116393
target_mae_open 0.4841027898557643


Unnamed: 0,lacl_pred_gfn2,source_pred_gfn2,target_pred_gfn2,lacl_pred_mmff,source_pred_mmff,target_pred_mmff,label
0,5.298273,6.129177,4.797135,6.096734,6.273084,6.140422,6.359372
1,4.613882,5.989444,4.192105,5.600502,6.053217,5.913217,5.748774
2,4.332472,4.855957,4.542428,5.477022,5.844475,5.296608,5.196109
3,4.710709,5.383428,4.863209,5.238517,6.107234,5.525047,5.77159
4,4.675867,4.764454,4.420328,5.200075,4.868182,4.931461,4.791993
5,4.709696,5.033743,4.609104,5.158689,4.920982,4.827761,5.112519
6,5.884977,6.052667,5.313303,6.289101,6.572124,5.644134,5.726808
7,9.971092,10.271099,9.5532,10.26364,10.260971,9.725629,10.501991
8,4.783034,4.381783,4.57937,4.931899,4.209332,4.487748,4.259352
9,3.539004,3.893847,3.942779,3.719813,3.728625,3.944273,3.512984


In [86]:
view = py3Dmol.view(width=800, height=600, data=Chem.MolToMolBlock(imagemols[0]), style={"stick": {}, "sphere": {"scale": 0.3}})
view.zoomTo()

<py3Dmol.view at 0x7fb7fc516df0>