## Purpose

This notebook uses RDKit to generate images and bounding boxes for given InChI strings. It's based on [this notebook](https://www.kaggle.com/stainsby/extract-molecule-structure-with-rdkit) by [@stainsby](https://www.kaggle.com/stainsby) and aims to accomplish step 2 of [this outline](https://www.kaggle.com/c/bms-molecular-translation/discussion/229984) by [@CPMP](https://www.kaggle.com/cpmpml).

In [None]:
# Install the third-party packages the rest of the notebook imports.
print('Installing required packages. This can take a while ...')
# We use a specific version of RDKit to ensure consistent results.
!conda install -q -y -c rdkit rdkit=2020_03_6
# cairosvg is pinned as well; it is used to rasterise SVG into PNG.
!pip -q install cairosvg==2.5.2
print('DONE.')

## Imports

In [None]:
from io import BytesIO

import numpy as np
from PIL import Image

from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import rdMolDraw2D

import lxml.etree as et
import cairosvg

from matplotlib import patches, pyplot as plt

## Implementation

We will be extracting the bbox center coordinates, height and width (all normalised by the image size) as shown below.

![Bounding box described by its center coordinates, width and height](https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection/master/img/cs.PNG)

In [None]:
def svg_to_image(svg, convert_to_greyscale=True):
    """Rasterise an lxml SVG element into a float32 NumPy array.

    The element is serialised, rendered to PNG bytes by cairosvg and decoded
    with PIL.

    Args:
        svg: Parsed SVG element accepted by ``lxml.etree.tostring``.
        convert_to_greyscale: When true, collapse the last axis of the decoded
            array by taking its unweighted mean (every channel present in the
            PNG takes part in the average).

    Returns:
        The decoded image as ``np.float32``; the last axis is averaged away
        when ``convert_to_greyscale`` is true.
    """
    png_bytes = cairosvg.svg2png(bytestring=et.tostring(svg))
    # TODO: would prefer to convert SVG directly to a numpy array.
    pixels = np.array(Image.open(BytesIO(png_bytes)), dtype=np.float32)
    if not convert_to_greyscale:
        return pixels
    # Naive greyscale conversion: plain mean over the last axis.
    return pixels.mean(axis=-1)


def drawmol(mol, size, add_hs=False):
    """Render a molecule and return normalised bounding boxes for atoms and bonds.

    The molecule is drawn with RDKit onto an SVG canvas that is ``2 * size``
    pixels wide and ``size`` pixels high, rasterised through ``svg_to_image``,
    and displayed with matplotlib with atom boxes in red and bond boxes in
    blue.

    Args:
        mol: RDKit molecule to draw. It is copied first, so the caller's
            object is not modified.
        size: Canvas height in pixels; the canvas width is twice this value.
        add_hs: When true, hydrogens are added with ``Chem.AddHs`` before
            drawing.

    Returns:
        A list with one dict per atom of the prepared molecule, in atom-index
        order. Each dict holds ``x``/``y`` (box centre) and ``width``/
        ``height``, all divided by the canvas width or height respectively,
        plus ``id`` (atom index), ``name`` (element symbol) and ``bonds``.
        ``bonds`` lists the atom's bonds as dicts with ``from``, ``to`` and
        ``type``. A bond is listed under both of its atoms, but only the copy
        stored under its begin atom (``from == id``) also carries ``x``,
        ``y``, ``width`` and ``height``.
    """
    # Work on a private copy: Compute2DCoords writes coordinates into the
    # molecule it receives, which would otherwise alter the caller's object.
    mol = Chem.AddHs(mol) if add_hs else Chem.Mol(mol)

    h_size = 2 * size
    v_size = size
    # Atoms get a fixed 15 px box; bond boxes are at least 20 px per side so
    # that horizontal or vertical bonds do not collapse to zero thickness.
    atom_width, atom_height = 15 / h_size, 15 / v_size
    min_bond_width, min_bond_height = 20 / h_size, 20 / v_size

    drawer = rdMolDraw2D.MolDraw2DSVG(h_size, v_size)
    AllChem.Compute2DCoords(mol)
    mol = rdMolDraw2D.PrepareMolForDrawing(mol)
    drawer.DrawMolecule(mol)
    drawer.FinishDrawing()
    svg = et.fromstring(drawer.GetDrawingText().encode('iso-8859-1'))
    # Scale to [0, 1] and invert so that drawn strokes have high values.
    img = 1 - svg_to_image(svg) / 255

    # Atom boxes, centred on the pixel position the drawer reports.
    bbox_info = []
    for atom in mol.GetAtoms():
        atom_idx = atom.GetIdx()
        point = drawer.GetDrawCoords(atom_idx)
        bonds = [
            {
                'from': bond.GetBeginAtomIdx(),
                'to': bond.GetEndAtomIdx(),
                'type': bond.GetBondType(),
            }
            for bond in atom.GetBonds()
        ]
        bbox_info.append({
            'x': point.x / h_size,
            'y': point.y / v_size,
            'width': atom_width,
            'height': atom_height,
            'id': atom_idx,
            'name': atom.GetSymbol(),
            'bonds': bonds,
        })

    # Bond boxes: the rectangle spanned by the two atom centres. Geometry is
    # attached only to the copy held by the bond's begin atom, so each bond is
    # measured (and later drawn) exactly once.
    for atom_box in bbox_info:
        for bond in atom_box['bonds']:
            if bond['from'] != atom_box['id']:
                continue
            other_box = bbox_info[bond['to']]
            bond['x'] = (atom_box['x'] + other_box['x']) / 2
            bond['y'] = (atom_box['y'] + other_box['y']) / 2
            bond['width'] = max(abs(atom_box['x'] - other_box['x']), min_bond_width)
            bond['height'] = max(abs(atom_box['y'] - other_box['y']), min_bond_height)

    def get_matplotlib_coors(box):
        # matplotlib rectangles are anchored at a corner, in pixels, whereas
        # the boxes store a normalised centre.
        left = (box['x'] - box['width'] / 2) * h_size
        top = (box['y'] - box['height'] / 2) * v_size
        return (left, top)

    def add_box(box, colour):
        ax.add_patch(patches.Rectangle(
            get_matplotlib_coors(box),
            box['width'] * h_size,
            box['height'] * v_size,
            fill=False,
            edgecolor=colour,
        ))

    fig, ax = plt.subplots(figsize=(15, 15))
    ax.imshow(img, cmap='gray_r', interpolation='bilinear')
    for atom_box in bbox_info:
        add_box(atom_box, 'red')
        for bond in atom_box['bonds']:
            if bond['from'] == atom_box['id']:
                add_box(bond, 'blue')
    plt.show()
    return bbox_info

In [None]:
# Sample InChI strings used to exercise drawmol.
INCHI = [
    "InChI=1S/C19H18ClN3O5S/c1-8-11(12(22-28-8)9-6-4-5-7-10(9)20)15(24)21-13-16(25)23-14(18(26)27)19(2,3)29-17(13)23/h4-7,13-14,17H,1-3H3,(H,21,24)(H,26,27)/t13-,14+,17-/m1/s1",
    "InChI=1S/C31H25NO3/c1-34-31-17-9-23(8-16-30(33)25-10-13-28(14-11-25)32-18-4-5-19-32)20-27(31)22-35-29-15-12-24-6-2-3-7-26(24)21-29/h2-21H,22H2,1H3/b16-8+",
    "InChI=1S/C18H10Br2N2O4S/c19-11-1-3-13-9(5-11)7-15(17(23)21-13)27(25,26)16-8-10-6-12(20)2-4-14(10)22-18(16)24/h1-8H,(H,21,23)(H,22,24)",
    "InChI=1S/C20H24N2O4S/c1-10-5-6-12-15(7-10)27-20-16(12)19(23)21-18(22-20)11-8-13(24-2)17(26-4)14(9-11)25-3/h8-10,18,22H,5-7H2,1-4H3,(H,21,23)/t10-,18-/m1/s1",
    "InChI=1S/C36H39FN2O7/c1-2-43-34(35(40)41)26-27-10-16-30(17-11-27)45-25-23-39(22-6-7-24-44-31-18-12-28(37)13-19-31)36(42)38-29-14-20-33(21-15-29)46-32-8-4-3-5-9-32/h3-5,8-21,34H,2,6-7,22-26H2,1H3,(H,38,42)(H,40,41)",
]

for _inch in INCHI:
    mol = Chem.inchi.MolFromInchi(_inch)
    if mol is None:
        # MolFromInchi gives None for input it cannot parse; report it and
        # move on instead of failing deep inside drawmol.
        print(f'Could not parse InChI, skipping: {_inch}')
        continue
    bbox_info = drawmol(mol, size=200)
    # display(bbox_info)  # uncomment to inspect the raw box dictionaries


I'll be using this bbox info to train an object detection model in [detectron2](https://github.com/facebookresearch/detectron2).