In [55]:
import pandas as pd 

from ogb.utils.features import (
    allowable_features, atom_to_feature_vector, bond_feature_vector_to_dict,
    bond_to_feature_vector, atom_feature_vector_to_dict
)
import torch
import rdkit
from rdkit import Chem
import numpy as np
import os

import re
import json
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw

In [66]:

def smiles2graph(smiles_string, with_amap=False):
    """
    Converts SMILES string to graph Data object
    :input: SMILES string (str)
    :return: graph object
    """

    mol = Chem.MolFromSmiles(smiles_string)
    if with_amap:
        if len(mol.GetAtoms()) > 0:
            max_amap = max([atom.GetAtomMapNum() for atom in mol.GetAtoms()])
            for atom in mol.GetAtoms():
                if atom.GetAtomMapNum() == 0:
                    atom.SetAtomMapNum(max_amap + 1)
                    max_amap = max_amap + 1

            amap_idx = {
                atom.GetAtomMapNum(): atom.GetIdx()
                for atom in mol.GetAtoms()
            }
        else:
            amap_idx = dict()

    # atoms
    atom_features_list = []
    for atom in mol.GetAtoms():
        if len(atom_to_feature_vector(atom))!=9:
            print(len(atom_to_feature_vector(atom)), smiles_string)
        atom_features_list.append(atom_to_feature_vector(atom))

    num_atom_features = 9
    if len(atom_features_list) > 0:
        x = np.array(atom_features_list, dtype=np.int64)
    else:
        x = np.empty((0, num_atom_features), dtype=np.int64)

    # bonds
    num_bond_features = 3  # bond type, bond stereo, is_conjugated
    if len(mol.GetBonds()) > 0:  # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()

            edge_feature = bond_to_feature_vector(bond)

            # add edges in both directions
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = np.array(edges_list, dtype=np.int64).T

        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = np.array(edge_features_list, dtype=np.int64)

    else:   # mol has no bonds
        edge_index = np.empty((2, 0), dtype=np.int64)
        edge_attr = np.empty((0, num_bond_features), dtype=np.int64)

    graph = dict()
    graph['edge_index'] = edge_index
    graph['edge_feat'] = edge_attr
    graph['node_feat'] = x
    graph['num_nodes'] = len(x)

    if with_amap:
        return graph, amap_idx
    else:
        return graph



def clear_map_number(smi):
    """Clear the atom mapping number of a SMILES sequence"""
    mol = Chem.MolFromSmiles(smi)
    for atom in mol.GetAtoms():
        if atom.HasProp('molAtomMapNumber'):
            atom.ClearProp('molAtomMapNumber')
    return canonical_smiles(Chem.MolToSmiles(mol))


def canonical_smiles(smi):
    """Canonicalize a SMILES without atom mapping"""
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return smi
    else:
        canonical_smi = Chem.MolToSmiles(mol)
        # print('>>', canonical_smi)
        if '.' in canonical_smi:
            canonical_smi_list = canonical_smi.split('.')
            canonical_smi_list = sorted(
                canonical_smi_list, key=lambda x: (len(x), x)
            )
            canonical_smi = '.'.join(canonical_smi_list)
        return canonical_smi


def cano_with_am(smi):
    mol = Chem.MolFromSmiles(smi)
    tmol = deepcopy(mol)
    for atom in mol.GetAtoms():
        if atom.HasProp('molAtomMapNumber'):
            atom.ClearProp('molAtomMapNumber')

    ranks = list(Chem.CanonicalRankAtoms(mol))
    root_atom = int(np.argmin(ranks))
    return Chem.MolToSmiles(tmol, rootedAtAtom=root_atom, canonical=True)


def remove_am_wo_cano(smi):
    mol = Chem.MolFromSmiles(smi)
    for atom in mol.GetAtoms():
        if atom.HasProp('molAtomMapNumber'):
            atom.ClearProp('molAtomMapNumber')

    return Chem.MolToSmiles(mol, canonical=False)


def find_all_amap(smi):
    return list(map(int, re.findall(r"(?<=:)\d+", smi)))


In [43]:
df=pd.read_csv("./raw_test.csv")


prod_list,reac_list=[],[]
for rrp in df['reactants>reagents>production']:
    reac,prod = rrp.strip().split('>>')
    reac_list.append(reac)
    prod_list.append(prod)

In [63]:
atom_features_list=[]
for atom in mol.GetAtoms():
    if len(atom_to_feature_vector(atom))!=9:
        print(len(atom_to_feature_vector(atom)))
    atom_features_list.append(atom_to_feature_vector(atom))

In [67]:
g_list=[]
for p in prod_list:
    g_list.append(smiles2graph(p))

In [70]:
g_list[0].keys()

dict_keys(['edge_index', 'edge_feat', 'node_feat', 'num_nodes'])

In [53]:
def is_shuffled(A, B):
    visited=set()
    for row_A in A:
        found = False
        for j,row_B in enumerate(B):
            if row_A == row_B and j not in visited:
                found = True
                print(j)
                visited.add(j)
                break
        if not found:
            return False
    return True,visited

In [50]:
len(atom_features_list)

27

In [51]:
atom_features_list

[[5, 0, 4, 5, 3, 0, 2, 0, 0],
 [5, 0, 3, 5, 0, 0, 1, 0, 0],
 [7, 0, 1, 5, 0, 0, 1, 0, 0],
 [5, 0, 3, 5, 0, 0, 1, 1, 1],
 [5, 0, 3, 5, 1, 0, 1, 1, 1],
 [5, 0, 3, 5, 1, 0, 1, 1, 1],
 [5, 0, 3, 5, 0, 0, 1, 1, 1],
 [6, 0, 3, 5, 1, 0, 1, 1, 1],
 [5, 0, 3, 5, 1, 0, 1, 1, 1],
 [5, 0, 3, 5, 1, 0, 1, 1, 1],
 [5, 0, 3, 5, 0, 0, 1, 1, 1],
 [5, 0, 3, 5, 1, 0, 1, 1, 1],
 [5, 0, 4, 5, 3, 0, 2, 0, 0],
 [5, 0, 4, 5, 0, 0, 2, 0, 0],
 [5, 0, 4, 5, 3, 0, 2, 0, 0],
 [5, 0, 4, 5, 3, 0, 2, 0, 0],
 [7, 0, 2, 5, 0, 0, 1, 0, 0],
 [5, 0, 3, 5, 0, 0, 1, 0, 0],
 [7, 0, 1, 5, 0, 0, 1, 0, 0],
 [7, 0, 2, 5, 0, 0, 1, 0, 0],
 [5, 0, 3, 5, 0, 0, 1, 0, 0],
 [7, 0, 1, 5, 0, 0, 1, 0, 0],
 [7, 0, 2, 5, 0, 0, 1, 0, 0],
 [5, 0, 4, 5, 0, 0, 2, 0, 0],
 [5, 0, 4, 5, 3, 0, 2, 0, 0],
 [5, 0, 4, 5, 3, 0, 2, 0, 0],
 [5, 0, 4, 5, 3, 0, 2, 0, 0]]

In [36]:
mol = Chem.MolFromSmiles(r)
for atom in mol.GetAtoms():
    if atom.HasProp('molAtomMapNumber'):
        atom.ClearProp('molAtomMapNumber')
Chem.MolToSmiles(mol)

'CC(=O)c1ccc2[nH]ccc2c1.CC(C)(C)OC(=O)OC(=O)OC(C)(C)C'

In [37]:


canonical_smi = Chem.MolToSmiles(mol)
# print('>>', canonical_smi)
if '.' in canonical_smi:
    canonical_smi_list = canonical_smi.split('.')
    canonical_smi_list = sorted(
        canonical_smi_list, key=lambda x: (len(x), x)
    )
    canonical_smi = '.'.join(canonical_smi_list)
canonical_smi

'CC(=O)c1ccc2[nH]ccc2c1.CC(C)(C)OC(=O)OC(=O)OC(C)(C)C'

In [24]:
mol = Chem.MolFromSmiles(p)
print(mol)
for atom in mol.GetAtoms():
    if atom.HasProp('molAtomMapNumber'):
        print('True')
        atom.ClearProp('molAtomMapNumber')
                      

<rdkit.Chem.rdchem.Mol object at 0x00000234263A6190>
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True


In [26]:
Chem.MolToSmiles(mol)

'CC(=O)c1ccc2c(ccn2C(=O)OC(C)(C)C)c1'

In [27]:
p

'[CH3:1][C:2]([CH3:3])([CH3:4])[O:5][C:6](=[O:7])[n:15]1[c:14]2[cH:13][cH:12][c:11]([C:9]([CH3:8])=[O:10])[cH:19][c:18]2[cH:17][cH:16]1'

In [30]:
smiles2graph(prod_list[0])

{'edge_index': array([[ 0,  1,  1,  2,  1,  3,  1,  4,  4,  5,  5,  6,  5,  7,  7,  8,
          8,  9,  9, 10, 10, 11, 11, 12, 12, 13, 12, 14, 11, 15, 15, 16,
         16, 17, 17, 18, 18,  7, 16,  8],
        [ 1,  0,  2,  1,  3,  1,  4,  1,  5,  4,  6,  5,  7,  5,  8,  7,
          9,  8, 10,  9, 11, 10, 12, 11, 13, 12, 14, 12, 15, 11, 16, 15,
         17, 16, 18, 17,  7, 18,  8, 16]], dtype=int64),
 'edge_feat': array([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 1],
        [1, 0, 1],
        [1, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [3, 0, 1],
        [3, 0, 1],
        [3, 0, 1],
        [3, 0, 1],
        [3, 0, 1],
        [3, 0, 1],
        [3, 0, 1],
        [3, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 0],
        [0, 0, 0],
        [1, 0, 1],
        [1, 0, 1],
        [3, 0, 1],
        [3, 0, 1],
        [3, 

In [17]:
g.keys()

dict_keys(['edge_index', 'edge_feat', 'node_feat', 'num_nodes'])

In [29]:
g['node_feat']

array([[5, 0, 4, 5, 3, 0, 2, 0, 0],
       [5, 0, 3, 5, 0, 0, 1, 0, 0],
       [7, 0, 1, 5, 0, 0, 1, 0, 0],
       [5, 0, 3, 5, 0, 0, 1, 1, 1],
       [5, 0, 3, 5, 1, 0, 1, 1, 1],
       [5, 0, 3, 5, 1, 0, 1, 1, 1],
       [5, 0, 3, 5, 0, 0, 1, 1, 1],
       [5, 0, 3, 5, 0, 0, 1, 1, 1],
       [5, 0, 3, 5, 1, 0, 1, 1, 1],
       [5, 0, 3, 5, 1, 0, 1, 1, 1],
       [6, 0, 3, 5, 0, 0, 1, 1, 1],
       [5, 0, 3, 5, 0, 0, 1, 0, 0],
       [7, 0, 1, 5, 0, 0, 1, 0, 0],
       [7, 0, 2, 5, 0, 0, 1, 0, 0],
       [5, 0, 4, 5, 0, 0, 2, 0, 0],
       [5, 0, 4, 5, 3, 0, 2, 0, 0],
       [5, 0, 4, 5, 3, 0, 2, 0, 0],
       [5, 0, 4, 5, 3, 0, 2, 0, 0],
       [5, 0, 3, 5, 1, 0, 1, 1, 1]], dtype=int64)

In [18]:
g['edge_index']

array([[ 0,  1,  1,  2,  1,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,
         8,  9,  9, 10, 10, 11, 11, 12, 11, 13, 13, 14, 14, 15, 14, 16,
        14, 17,  7, 18, 18,  3, 10,  6],
       [ 1,  0,  2,  1,  3,  1,  4,  3,  5,  4,  6,  5,  7,  6,  8,  7,
         9,  8, 10,  9, 11, 10, 12, 11, 13, 11, 14, 13, 15, 14, 16, 14,
        17, 14, 18,  7,  3, 18,  6, 10]], dtype=int64)

In [10]:
smiles2graph(prod_list[0])

{'edge_index': array([[ 0,  1,  1,  2,  1,  3,  1,  4,  4,  5,  5,  6,  5,  7,  7,  8,
          8,  9,  9, 10, 10, 11, 11, 12, 12, 13, 12, 14, 11, 15, 15, 16,
         16, 17, 17, 18, 18,  7, 16,  8],
        [ 1,  0,  2,  1,  3,  1,  4,  1,  5,  4,  6,  5,  7,  5,  8,  7,
          9,  8, 10,  9, 11, 10, 12, 11, 13, 12, 14, 12, 15, 11, 16, 15,
         17, 16, 18, 17,  7, 18,  8, 16]], dtype=int64),
 'edge_feat': array([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 1],
        [0, 0, 1],
        [1, 0, 1],
        [1, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [3, 0, 1],
        [3, 0, 1],
        [3, 0, 1],
        [3, 0, 1],
        [3, 0, 1],
        [3, 0, 1],
        [3, 0, 1],
        [3, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 0],
        [0, 0, 0],
        [1, 0, 1],
        [1, 0, 1],
        [3, 0, 1],
        [3, 0, 1],
        [3, 