In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../../src')
from rdkit import Chem
from dfs_transformer.utils import DFSCode2Graph, Graph2Mol, isValid, Smiles2DFSCode, DFSCode2Smiles, isValidMoleculeDFSCode
from dfs_transformer.utils import load_selfattn_wandb, load_selfattn_local
import os.path
import torch
import torch.nn as nn
import dfs_transformer as dfs
import numpy as np

Using backend: pytorch


In [2]:
# Name of the pretrained model to load.
pretrained_model = "r2r-nofeat"

# Build the local artifact path once. If that directory exists, load from it;
# otherwise hand the model name and the wandb directory to the wandb loader.
_artifact_dir = "../../wandb/artifacts/%s" % pretrained_model
if not os.path.isdir(_artifact_dir):
    bert, cfg = load_selfattn_wandb(pretrained_model, wandb_dir="../../wandb")
else:
    bert, cfg = load_selfattn_local(_artifact_dir)

In [3]:
# Example molecules, given as SMILES strings, used as input for the cells below.
smiles_list = [
    'CN=C=O',
    'Cc1cn2c(CN(C)C(=O)c3ccc(F)cc3C)c(C)nc2s1',
    'Cc1cc(F)ccc1C(=O)N(C)Cc1c(C)nc2scc(C)n12',
]

In [4]:
# Convert every SMILES string into a graph object and attach the two feature
# attributes (edge_features, node_features) before collecting it in dlist.
dlist = []
for smiles in smiles_list:
    data = dfs.smiles2graph(smiles)
    # Expose the existing edge attributes under the name `edge_features`.
    data.edge_features = data.edge_attr
    # One-hot encode `z - 1` over 118 classes and cast to float.
    # NOTE(review): presumably `z` holds 1-based atomic numbers - confirm.
    one_hot_nodes = nn.functional.one_hot(data.z - 1, num_classes=118)
    data.node_features = one_hot_nodes.float()
    dlist.append(data)

In [5]:
# Collate the prepared graphs into model inputs and targets.
# NOTE(review): mode="rnd2rnd" and fraction_missing suggest a random masking
# step; no seed is set in this notebook, so results may differ between runs.
inputs, outputs = dfs.collate_BERT(
    dlist,
    mode="rnd2rnd",
    fraction_missing=0.15,
)

In [6]:
# Show the shape of the collated 'atm_from' tensor (recorded output: [25, 3, 118]).
# NOTE(review): presumably (sequence length, batch, one-hot classes) - confirm.
inputs['atm_from'].shape

torch.Size([25, 3, 118])

In [7]:
# Run the loaded model's fwd_code on the collated inputs; the result is
# iterated one code at a time in the validity check that follows.
code_list = bert.fwd_code(inputs)

In [8]:
# Run isValidMoleculeDFSCode on every code returned by the model. A code whose
# check raises is recorded as invalid (False) instead of aborting the loop.
valid_list = []
for code in code_list:
    try:
        valid_list.append(isValidMoleculeDFSCode(code))
    except Exception:
        # Catch Exception rather than using a bare `except:` so that
        # KeyboardInterrupt and SystemExit can still interrupt the cell.
        valid_list.append(False)
# Array form of the per-code results, used for the summary statistic below.
valid = np.asarray(valid_list)

In [9]:
# Sum of `valid` divided by its length: the fraction of codes marked valid
# when the entries are booleans (recorded output: 1.0).
valid.sum()/len(valid)

1.0

In [10]:
# Display the 'atm_from' shape again (recorded output: [25, 3, 118]) before
# `inputs` is replaced by hand-built tensors in the following cell.
inputs['atm_from'].shape

torch.Size([25, 3, 118])

In [11]:
# Hand-built model inputs in which every entry is -1. The two index tensors
# are integer (long) with shape (25, 10); the atom and bond tensors are float
# with a trailing dimension of 118 and 5 respectively.
# NOTE(review): -1 presumably marks a missing/masked position - confirm.
inputs = {
    'dfs_from': torch.full((25, 10), -1, dtype=torch.long),
    'dfs_to': torch.full((25, 10), -1, dtype=torch.long),
    'atm_from': torch.full((25, 10, 118), -1.0),
    'atm_to': torch.full((25, 10, 118), -1.0),
    'bnd': torch.full((25, 10, 5), -1.0),
}

In [12]:
# Overwrite the whole first row (index 0 along the first dimension) of the two
# index tensors: 'dfs_from' becomes 0 and 'dfs_to' becomes 1 in every column.
inputs['dfs_from'][0, :] = 0
inputs['dfs_to'][0, :] = 1

In [13]:
# Call fwd_code_all on the hand-built inputs and display the result; the
# recorded output is a nested list whose innermost entries have 5 integers.
bert.fwd_code_all(inputs)

[[[0, 1, 6, 0, 6],
  [1, 2, 6, 0, 6],
  [2, 3, 6, 0, 6],
  [3, 4, 6, 0, 6],
  [4, 5, 6, 0, 6],
  [5, 6, 6, 0, 6],
  [6, 7, 6, 0, 6],
  [6, 7, 6, 0, 6],
  [7, 8, 6, 0, 6],
  [8, 9, 6, 0, 6],
  [9, 10, 6, 0, 6],
  [9, 11, 6, 0, 6],
  [11, 12, 6, 0, 6],
  [12, 13, 6, 0, 6],
  [13, 14, 6, 0, 6],
  [14, 15, 6, 0, 6],
  [14, 15, 6, 0, 6],
  [15, 16, 6, 0, 6],
  [16, 18, 6, 0, 6],
  [17, 19, 6, 0, 6],
  [19, 20, 6, 0, 6],
  [19, 20, 6, 0, 6],
  [20, 21, 6, 0, 6],
  [0, 21, 6, 0, 6],
  [0, 23, 6, 0, 6]],
 [[0, 1, 6, 0, 6],
  [1, 2, 6, 0, 6],
  [2, 3, 6, 0, 6],
  [3, 4, 6, 0, 6],
  [4, 5, 6, 0, 6],
  [5, 6, 6, 0, 6],
  [6, 7, 6, 0, 6],
  [6, 7, 6, 0, 6],
  [7, 8, 6, 0, 6],
  [8, 9, 6, 0, 6],
  [8, 10, 6, 0, 6],
  [10, 11, 6, 0, 6],
  [11, 12, 6, 0, 6],
  [12, 13, 6, 0, 6],
  [13, 14, 6, 0, 6],
  [14, 15, 6, 0, 6],
  [14, 15, 6, 0, 6],
  [15, 16, 6, 0, 6],
  [16, 17, 6, 0, 6],
  [17, 18, 6, 0, 6],
  [18, 19, 6, 0, 6],
  [19, 20, 6, 0, 6],
  [19, 21, 6, 0, 6],
  [0, 21, 6, 0, 6],
  [0, 21, 6, 0, 

In [14]:
# Display the hand-built inputs dict; the recorded output shows row 0 of
# 'dfs_from' set to 0 and the remaining rows still at -1.
# NOTE(review): this prints every tensor in full - consider showing only the
# shapes or a small slice to keep the notebook output short.
inputs

{'dfs_from': tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -