In [71]:
import numpy as np
import pandas as pd
import sys
from rdkit import Chem
sys.path.append('../')
from bms.chemistry import _expand_functional_group, _verify_chirality
from bms.constants import RGROUP_SYMBOLS, PLACEHOLDER_ATOMS, SUBSTITUTIONS, ABBREVIATIONS
from tqdm import tqdm

In [76]:
EXPECTED_BONDS = {
    "C": 4,
    "[C@]": 4,
    "[C@@]": 4,
    "[C@H]": 3,
    "[C@@H]": 3,
    "O": 2,
    "[O-]": 1,
    "N": 3,
    "[N+]": 4,
    "[H]": 1
}

BondTypeMap = {
    Chem.BondType.SINGLE: 1.0,
    Chem.BondType.DOUBLE: 2.0,
    Chem.BondType.TRIPLE: 3.0,
    Chem.BondType.AROMATIC: 1.5,
    Chem.BondType.UNSPECIFIED: 0.0
}

In [77]:
def _convert_graph_to_smiles(coords, symbols, edges, edges_prob, debug=False):
    mol = Chem.RWMol()
    n = len(symbols)
    ids = []
    symbol_to_placeholder = {}
    mappings = []
    for i in range(n):
        symbol = symbols[i]
        if symbol[0] == '[':
            symbol = symbol[1:-1]
        if symbol in RGROUP_SYMBOLS:
            atom = Chem.Atom("*")
            atom.SetIsotope(RGROUP_SYMBOLS.index(symbol))
            idx = mol.AddAtom(atom)
        elif symbol in ABBREVIATIONS:
            if symbol not in symbol_to_placeholder:
                j = len(symbol_to_placeholder)
                assert j < len(PLACEHOLDER_ATOMS), "Not enough placeholders"
                placeholder = PLACEHOLDER_ATOMS[j]
                symbol_to_placeholder[symbol] = placeholder
            else:
                placeholder = symbol_to_placeholder[symbol]
            sub = ABBREVIATIONS[symbol]
            mappings.append((placeholder, sub.smiles))
            atom = Chem.Atom(placeholder)
            idx = mol.AddAtom(atom)
        else:
            symbol = symbols[i]
            try:
                atom = Chem.AtomFromSmiles(symbol)
                atom.SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
                idx = mol.AddAtom(atom)
            except:
                atom = Chem.Atom("*")
                idx = mol.AddAtom(atom)
        assert idx == i
        ids.append(idx)

    has_chirality = False

    for i in range(n):
        for j in range(i + 1, n):
            if edges[i][j] == 1:
                mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
            elif edges[i][j] == 2:
                mol.AddBond(ids[i], ids[j], Chem.BondType.DOUBLE)
            elif edges[i][j] == 3:
                mol.AddBond(ids[i], ids[j], Chem.BondType.TRIPLE)
            elif edges[i][j] == 4:
                mol.AddBond(ids[i], ids[j], Chem.BondType.AROMATIC)
            elif edges[i][j] == 5:
                mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
                mol.GetBondBetweenAtoms(ids[i], ids[j]).SetBondDir(Chem.BondDir.BEGINDASH)
                has_chirality = True
            elif edges[i][j] == 6:
                mol.AddBond(ids[i], ids[j], Chem.BondType.SINGLE)
                mol.GetBondBetweenAtoms(ids[i], ids[j]).SetBondDir(Chem.BondDir.BEGINWEDGE)
                has_chirality = True
    
    for i in range(n):
        symbol = symbols[i]
        if not symbol in EXPECTED_BONDS:
            continue
        
        atom = mol.GetAtomWithIdx(i)
        bond_orders = [BondTypeMap[bond.GetBondType()] for bond in atom.GetBonds()]
        if not all(o == 1 for o in bond_orders):
            continue                        # only override if all bonds are single
        
        expected_bonds = EXPECTED_BONDS[symbol]
        if sum(bond_orders) <= expected_bonds:
            continue                        # rdkit should be able to add H's
        
        neighbors = [bond.GetOtherAtomIdx(i) for bond in atom.GetBonds()]
        bond_probs = {j: edges_prob[i][j][edges[i][j]] for j in neighbors}
        sorted_probs = sorted(bond_probs.items(), key=lambda _tup: _tup[1], reverse=True)
        print(f"symbol: {symbol}, sorted probs: {sorted_probs}")
        
        for j, prob in sorted_probs[expected_bonds:]:
            mol.RemoveBond(i, j)
            edges[i][j] = 0                 # clean up and override edges
            edges[j][i] = 0
        
    pred_smiles = '<invalid>'

    try:
        if has_chirality:
            mol = _verify_chirality(mol, coords, symbols, edges, debug)
        # pred_smiles = Chem.MolToSmiles(mol, isomericSmiles=True, canonical=True)
        pred_smiles = _expand_functional_group(mol, mappings)
        success = True
    except Exception as e:
        if debug:
            raise e
        success = False

    return pred_smiles, success

In [41]:
fn = "../prediction_CLEF.csv"
df = pd.read_csv(fn)
df[:2]

Unnamed: 0,image_id,SMILES,node_coords,node_symbols,edges_prob,edges,graph_SMILES,post_SMILES
0,US20070117785A1_p0034_x0977_y2564_c00072,CC(=O)N1CCC2=C(C1)C(C1=CC=C(C(F)(F)F)C=C1)=NN2...,"[[0.746,0.921],[0.698,0.889],[0.651,1.0],[0.68...","[""C"",""C"",""O"",""N"",""C"",""C"",""C"",""C"",""C"",""C"",""C"",""...","[[[1.0, 2.8564711485046246e-08, 7.287724433256...","[[0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...",CC(=O)N1CCC2=C(C1)C(C1=CC=C(C(F)(F)F)C=C1)=NN2...,CC(=O)N1CCc2c(c(-c3ccc(C(F)(F)F)cc3)nn2CC(O)CN...
1,US20050004369A1_p0017_x1399_y0692_c00023,C1C=CC2N(C3C=CC(C4C=CC(CO)=CC=4)=CC=3)C=CC=2C=1,"[[0.0,0.19],[0.079,0.032],[0.175,0.19],[0.175,...","[""C"",""C"",""C"",""C"",""N"",""C"",""C"",""C"",""C"",""C"",""C"",""...","[[[0.9999997615814209, 1.8590182548905432e-07,...","[[0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...",OCC1=CC=C(C2=CC=C(N3C=CC4=C3C=CC=C4)C=C2)C=C1,OCc1ccc(-c2ccc(-n3ccc4ccccc43)cc2)cc1


In [78]:
for i in tqdm(range(991)):
# for i in [0]:
    coords = np.asarray(eval(df.at[i, "node_coords"]))
    symbols = eval(df.at[i, "node_symbols"])
    edges_prob = np.asarray(eval(df.at[i, "edges_prob"]))
    edges = np.asarray(eval(df.at[i, "edges"]))
    _convert_graph_to_smiles(coords, symbols, edges, edges_prob, debug=False)

 38%|███▊      | 380/991 [00:04<00:07, 76.55it/s]

symbol: C, sorted probs: [(13, 0.9999995231628418), (40, 0.9999995231628418), (11, 0.9999985694885254), (81, 0.9032085537910461), (52, 0.5046385526657104)]
symbol: N, sorted probs: [(19, 1.0), (21, 1.0), (31, 0.9999997615814209), (72, 0.5200088620185852)]
symbol: N, sorted probs: [(45, 0.9999994039535522), (50, 0.9999992847442627), (53, 0.9999922513961792), (4, 0.8661325573921204)]


 51%|█████     | 506/991 [00:05<00:06, 79.89it/s]

symbol: N, sorted probs: [(42, 0.9999980926513672), (5, 0.9999947547912598), (3, 0.9999946355819702), (51, 0.8681014776229858)]
symbol: C, sorted probs: [(8, 0.9999994039535522), (41, 0.9999988079071045), (5, 0.999967098236084), (55, 0.9989981055259705), (53, 0.8920459151268005)]
symbol: O, sorted probs: [(38, 0.9999997615814209), (34, 0.9999991655349731), (32, 0.9999947547912598)]
symbol: N, sorted probs: [(26, 0.9999555349349976), (28, 0.9999034404754639), (29, 0.9994499087333679), (30, 0.9975269436836243)]


 60%|█████▉    | 593/991 [00:06<00:04, 89.26it/s]

symbol: O, sorted probs: [(30, 1.0), (32, 0.9999997615814209), (36, 0.9999934434890747)]
symbol: O, sorted probs: [(4, 0.999998927116394), (41, 0.9999988079071045), (2, 0.9999923706054688), (46, 0.9995842576026917)]
symbol: C, sorted probs: [(7, 0.9999997615814209), (40, 0.9999972581863403), (48, 0.9999912977218628), (4, 0.9999557733535767), (46, 0.9966691136360168)]
symbol: N, sorted probs: [(21, 0.9997959733009338), (23, 0.9996703863143921), (24, 0.9992868304252625), (25, 0.9536213874816895)]


 76%|███████▌  | 753/991 [00:08<00:02, 80.36it/s]

symbol: C, sorted probs: [(12, 0.9999996423721313), (45, 0.9999986886978149), (3, 0.9999972581863403), (9, 0.9999576807022095), (46, 0.9997028708457947)]
symbol: N, sorted probs: [(26, 0.9999725818634033), (28, 0.9998464584350586), (29, 0.999377429485321), (30, 0.9976274371147156)]


 87%|████████▋ | 860/991 [00:09<00:01, 86.46it/s]

symbol: O, sorted probs: [(4, 0.9999908208847046), (2, 0.9999904632568359), (40, 0.9999839067459106), (45, 0.9977172613143921)]
symbol: N, sorted probs: [(20, 0.99991774559021), (22, 0.9997897744178772), (24, 0.9994865655899048), (23, 0.9991710186004639)]


100%|██████████| 991/991 [00:11<00:00, 87.07it/s]
