In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../../src')
from rdkit import Chem
from dfs_transformer.utils import DFSCode2Graph, Graph2Mol, isValid, Smiles2DFSCode, DFSCode2Smiles, isValidMoleculeDFSCode
from dfs_transformer.utils import load_selfattn_wandb, load_selfattn_local
import os.path
import torch
import torch.nn as nn
import dfs_transformer as dfs


Using backend: pytorch


In [2]:
# Name of the pretrained checkpoint to load.
pretrained_model = "r2r-nofeat"

# Build the local artifact path once and reuse it for both the existence
# check and the local load.
artifact_dir = "../../wandb/artifacts/%s" % pretrained_model

if not os.path.isdir(artifact_dir):
    # No local copy found: fall back to load_selfattn_wandb
    # (presumably fetches the artifact into ../../wandb — TODO confirm).
    bert, cfg = load_selfattn_wandb(pretrained_model, wandb_dir="../../wandb")
else:
    bert, cfg = load_selfattn_local(artifact_dir)

In [3]:
# Example molecules, given as SMILES strings: one short string followed by
# two longer ones. The order matters because later cells index into the
# per-molecule results by position.
smiles_list = [
    "CN=C=O",
    "Cc1cn2c(CN(C)C(=O)c3ccc(F)cc3C)c(C)nc2s1",
    "Cc1cc(F)ccc1C(=O)N(C)Cc1c(C)nc2scc(C)n12",
]

In [4]:
# Convert every SMILES string into a graph object and attach the feature
# attributes (presumably the ones dfs.collate_BERT reads — TODO confirm).
dlist = []
for smiles in smiles_list:
    data = dfs.smiles2graph(smiles)

    # Edge features are taken unchanged from the graph's edge attributes.
    data.edge_features = data.edge_attr

    # Node features: one-hot encoding of (z - 1) over 118 classes, as floats.
    # NOTE(review): z looks like a 1-based atomic number, hence the shift to
    # a 0-based class index — confirm against dfs.smiles2graph.
    atom_index = data.z - 1
    data.node_features = nn.functional.one_hot(atom_index, num_classes=118).float()

    dlist.append(data)

In [5]:
# Collate the per-molecule graphs into one batch of model inputs and targets.
# NOTE(review): mode="rnd2rnd" and fraction_missing=0.15 suggest random
# ordering/masking; no random seed is set in the visible cells, so re-running
# may produce different inputs and downstream outputs — confirm.
inputs, outputs = dfs.collate_BERT(
    dlist,
    mode="rnd2rnd",
    fraction_missing=0.15,
)

In [6]:
# Display the shape of the 'atm_from' entry of the collated inputs.
# NOTE(review): the recorded shape [25, 3, 118] looks like
# (max code length, batch size, atom classes) — confirm against collate_BERT.
inputs['atm_from'].shape

torch.Size([25, 3, 118])

In [7]:
# Run the loaded model on the collated batch. The result is indexed and
# iterated per molecule in the cells below.
# NOTE(review): each entry appears to be a list of 5-integer rows, presumably
# (from index, to index, from atom, bond type, to atom) — confirm.
code_list = bert.fwd_code(inputs)

In [8]:
# Display the model's output code for the first entry of the batch
# (presumably corresponding to smiles_list[0] — confirm ordering).
code_list[0]

[[0, 1, 6, 1, 7], [1, 2, 7, 0, 6], [0, 3, 6, 1, 8]]

In [9]:
# Display the model's output code for the second entry of the batch
# (presumably corresponding to smiles_list[1] — confirm ordering).
code_list[1]

[[0, 1, 6, 2, 6],
 [1, 2, 6, 2, 6],
 [2, 3, 6, 2, 6],
 [3, 4, 6, 2, 6],
 [4, 5, 6, 2, 6],
 [5, 0, 6, 2, 6],
 [4, 6, 6, 0, 6],
 [6, 7, 6, 1, 8],
 [6, 8, 6, 0, 7],
 [8, 9, 7, 0, 6],
 [9, 10, 7, 0, 7],
 [10, 11, 6, 0, 6],
 [11, 12, 6, 2, 6],
 [12, 13, 6, 0, 6],
 [12, 14, 6, 2, 7],
 [14, 15, 8, 2, 7],
 [15, 16, 6, 2, 7],
 [16, 11, 7, 2, 6],
 [16, 17, 7, 2, 6],
 [17, 18, 6, 2, 6],
 [18, 19, 6, 2, 16],
 [19, 15, 16, 2, 6],
 [18, 20, 6, 0, 6],
 [3, 21, 6, 0, 6],
 [1, 22, 7, 0, 7]]

In [10]:
# Check every output code with isValidMoleculeDFSCode and print one boolean
# per code, in batch order. Any diagnostics emitted during the check show up
# interleaved with the printed results.
for code in code_list:
    is_valid = isValidMoleculeDFSCode(code)
    print(is_valid)

True
Can't kekulize mol.  Unkekulized atoms: 0 2 3 4 5

False
Explicit valence for atom # 18 N, 5, is greater than permitted
False


In [12]:
# Convert the second output code back into a SMILES string for display.
# NOTE(review): presumably meant to be compared by eye with smiles_list[1];
# this code was reported as not valid by the validity check above.
DFSCode2Smiles(code_list[1])

'Cc1cn2c(CNNC(=O)c3ccn(N)cc3C)c(C)oc2s1'