In [87]:
import pandas as pd

import os

# Location of the CSV holding the train/val/test split. The path used to be
# hard-coded to one user's Windows home directory, which made the notebook
# unusable anywhere else. Set the RETROSIM_DATA_SPLIT environment variable to
# point at the file on another machine; when it is unset the original path is
# used, so existing setups behave exactly as before.
DEFAULT_DATA_SPLIT_CSV = r'C:\Users\chaochaoyan\Documents\retrosynthesis\retrosim\retrosim\data\data_split.csv'
file = os.environ.get('RETROSIM_DATA_SPLIT', DEFAULT_DATA_SPLIT_CSV)

# Load the whole table; later statements read its 'prod_smiles', 'dataset'
# and 'rxn_smiles' columns.
data = pd.read_csv(file)

# Pull the three columns used by this notebook out of the frame as plain lists.
prod_smiles = data['prod_smiles'].tolist()
split = data['dataset'].tolist()
rxn = data['rxn_smiles'].tolist()

# Bucket every product SMILES under the split label found in the same row.
# A label other than 'train' / 'test' / 'val' raises KeyError.
products = {name: [] for name in ('train', 'test', 'val')}
for row, label in enumerate(split):
    products[label].append(prod_smiles[row])

# Report the bucket sizes in the order train, test, val.
print(*(len(products[name]) for name in ('train', 'test', 'val')))


40008 5007 5001


In [84]:
import os
import re
import sys
%matplotlib inline
import matplotlib.pyplot as plt

from collections import Counter
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import RDConfig
sys.path.append(os.path.join(RDConfig.RDContribDir, 'IFG'))
import ifg


def find_parent(parent, i):
    """Return the root (set representative) of element ``i``.

    Args:
        parent: Mapping (or sequence) from each element to its parent element.
            A root is marked by a parent value of ``-1``.
        i: The element whose root is wanted.

    Returns:
        The element reached by following parent links from ``i`` until an
        entry equal to ``-1`` is found (``i`` itself if it is a root).

    Raises:
        KeyError / IndexError: if ``i`` or any visited link is missing from
            ``parent``.

    The links are followed with a loop rather than recursion so that a long
    parent chain cannot exceed Python's recursion limit. ``parent`` is not
    modified (no path compression).
    """
    while parent[i] != -1:
        i = parent[i]
    return i
    
def union(parent, x, y):
    """Merge the sets containing ``x`` and ``y`` and return ``parent``.

    Args:
        parent: Mapping from element to parent element, with ``-1`` marking a
            root. It is updated in place.
        x: An element of the first set.
        y: An element of the second set.

    Returns:
        The same ``parent`` object, after linking the root of ``x``'s set to
        the root of ``y``'s set.
    """
    x_set = find_parent(parent, x)
    y_set = find_parent(parent, y)
    # When both elements already share a root, linking the root to itself
    # would overwrite its -1 marker with a self-loop and make every later
    # find_parent() call on that set loop forever, so leave it untouched.
    if x_set != y_set:
        parent[x_set] = y_set
    return parent

# Optional, off-by-default stages of the per-molecule loop below. They replace
# the bare `continue` statements that previously made the drawing and the
# fragment-analysis code unreachable; with both flags False the loop only
# fills `ring_counter` and `token`, exactly as before.
DRAW_MOLECULES = False      # render each molecule with its ring atoms highlighted
ANALYZE_FRAGMENTS = False   # functional groups + leftover connected fragments

# ring_counter: normalised ring pattern -> number of occurrences
# token:        single character of the raw ring SMILES -> number of occurrences
ring_counter = Counter()
token = Counter()
for k, smiles in enumerate(products['train']):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # RDKit returns None for a SMILES it cannot parse; skip it rather
        # than crash on the first attribute access.
        continue

    # --- Ring statistics -------------------------------------------------
    # (An earlier variant built the ring SMILES from bond rings via
    # Chem.PathToSubmol; the atom-ring approach below is the one in use.)
    all_atoms = [a.GetIdx() for a in mol.GetAtoms()]
    for ring in mol.GetRingInfo().AtomRings():
        ring_atoms = set(ring)
        emol = Chem.EditableMol(mol)
        # Delete every atom outside this ring, highest index first: removing
        # an atom only renumbers atoms with a higher index.
        for index in sorted((a for a in all_atoms if a not in ring_atoms), reverse=True):
            emol.RemoveAtom(index)
        smi = Chem.MolToSmiles(emol.GetMol(), isomericSmiles=True)

        # Counter.update() on a string counts its individual characters.
        token.update(smi)

        # Reduce the ring SMILES to a lower-case skeleton: drop the characters
        # '!', 'h', '#', '=', '+', '-', the square brackets and the digit '2'.
        ringsmi = smi.lower()
        ringsmi = re.sub('[!h#=+-]', '', ringsmi)
        ringsmi = ringsmi.replace('[', '')
        ringsmi = ringsmi.replace(']', '')
        ringsmi = ringsmi.replace('2', '')
        ring_counter.update([ringsmi])

    if not (DRAW_MOLECULES or ANALYZE_FRAGMENTS):
        continue

    # --- Ring atoms to highlight ------------------------------------------
    highlights = []
    colors = {}
    for ring in Chem.GetSymmSSSR(mol):
        highlights += list(ring)
        for r in ring:
            colors[r] = (1,0,0)

    if DRAW_MOLECULES:
        img = Draw.MolToImage(mol, size=(300, 300), highlightAtoms=highlights)
        plt.imshow(img)
        plt.tight_layout()
        plt.axis('off')
        plt.show()

    if not ANALYZE_FRAGMENTS:
        continue

    # --- Functional groups --------------------------------------------------
    fgs = ifg.identify_functional_groups(mol)
    for fg in fgs:
        highlights += fg.atomIds
    # NOTE(review): this recolours the ring atoms as well, because
    # `highlights` already contains them -- confirm that is intended.
    for hl in highlights:
        colors[hl] = (0,0,1)

    # Atoms that belong neither to a ring nor to a functional group.
    atoms = [a.GetIdx() for a in mol.GetAtoms()]
    highlighted = set(highlights)
    atoms_left = [idx for idx in atoms if idx not in highlighted]

    # Find all connected components with union-find algorithm
    parent = {}
    for a in atoms_left:
        parent[a] = -1
    for bond in mol.GetBonds():
        beg = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        if beg in parent and end in parent:
            parent = union(parent, beg, end)

    # Group the leftover atoms by their root. Following single leaf-to-root
    # paths (the previous approach) yields overlapping, partial components
    # whenever several atoms are linked to the same root.
    by_root = {}
    for a in atoms_left:
        by_root.setdefault(find_parent(parent, a), []).append(a)
    components = list(by_root.values())


In [85]:
# One line each: number of distinct ring patterns, the per-character counts of
# the raw ring SMILES, and the distinct ring patterns themselves.
print(len(ring_counter), token, ring_counter.keys(), sep='\n')


263
Counter({'c': 489127, '1': 243290, 'C': 122344, 'n': 50040, 'N': 22353, 'O': 6279, 's': 4030, '[': 3617, ']': 3617, 'H': 3380, 'o': 2759, '=': 2047, 'S': 918, '-': 583, '+': 226, 'B': 124, 'e': 9, 'i': 5, 'P': 4, '2': 3, '#': 1})
dict_keys(['c1ccccc1', 'c1ccncc1', 'c1ccsc1', 'c1cscn1', 'c1cocn1', 'c1cnoc1', 'c1cnccn1', 'c1cnc1', 'c1ccnc1', 'c1cncnc1', 'c1cnnc1', 'c1cccc1', 'c1cncn1', 'c1ccocc1', 'c1coco1', 'c1cnnn1', 'c1csccnc1', 'c1coccn1', 'c1ncccnc1', 'c1cc1', 'c1ccoc1', 'c1ccncncnccccc1', 'c1cnncn1', 'c1ccnnc1', 'c1ncnn1', 'c1cnccnc1', 'c1coccc1', 'c1cnccc1', 'c1cscccn1', 'c1cocccn1', 'c1ncon1', 'c1ccc1', 'c1cncc1', 'c1cnncc1', 'c1ncnc1', 'c1cccncc1', 'c1nnnn1', 'c1cscco1', 'c1cocc1', 'c1ccocn1', 'b1ccco1', 'c1nnco1', 'c1cncccc1', 'c1csccs1', 'c1coc1', 'c1cscccc1', 'c1cococ1', 'c1csccn1', 'c1nncs1', 'b1occo1', 'c1coccnc1', 'c1co1', 'c1nccnc1', 'c1nocn1', 'c1nccs1', 'c1ccnco1', 'c1nccco1', 'c1cncccn1', 'c1cocco1', 'c1csnc1', 'c1ccnccc1', 'c1ncsn1', 'c1cccccc1', 'c1ccncn1', 'c1nc

In [86]:
# Dump every normalised ring pattern together with its occurrence count
# (the Counter prints its entries from most to least frequent).
print(ring_counter)

Counter({'c1ccccc1': 59364, 'c1ccncc1': 17488, 'c1ccnc1': 6540, 'c1cncnc1': 4724, 'c1cnccn1': 3562, 'c1cncn1': 3541, 'c1cnnc1': 3273, 'c1ccsc1': 2025, 'c1cc1': 1951, 'c1cscn1': 1932, 'c1cccc1': 1844, 'c1coccn1': 1593, 'c1ccoc1': 1351, 'c1ccocc1': 1009, 'c1cocn1': 799, 'c1ccnnc1': 753, 'c1cncc1': 700, 'c1ncnn1': 667, 'c1cnoc1': 626, 'c1cnc1': 535, 'c1coco1': 487, 'c1cnccc1': 439, 'c1ccc1': 423, 'c1cocc1': 350, 'c1nnnn1': 309, 'c1cnnn1': 302, 'c1cccncc1': 291, 'c1ncon1': 284, 'c1coccc1': 253, 'c1cnccnc1': 241, 'c1cccccc1': 190, 'c1cocco1': 140, 'c1nnco1': 132, 'c1csccn1': 127, 'c1co1': 125, 'c1ncncn1': 117, 'c1cnncn1': 103, 'c1cncccc1': 100, 'c1cococ1': 99, 'c1nncs1': 95, 'c1nccnc1': 94, 'c1cncccn1': 94, 'c1cnsc1': 92, 'b1occo1': 87, 'c1nccncc1': 85, 'c1cccnc1': 84, 'c1coccnc1': 82, 'c1coc1': 78, 'c1ccnccc1': 76, 'c1nccn1': 72, 'c1ncsn1': 68, 'c1ncccn1': 66, 'c1ncccc1': 66, 'c1nocc1': 64, 'c1cnsn1': 54, 'c1cncoc1': 54, 'c1ccscc1': 53, 'c1nccc1': 47, 'c1cnccoc1': 44, 'c1nnccc1': 43, 'c1nc