## Purpose

A quick notebook to show how to extract key structural information from molecule objects in RDKit.

In [None]:
# Environment setup: install the pinned third-party packages this notebook imports.
# Lines starting with "!" are handed to the system shell instead of being run as Python.
print('Installing required packages. This can take a while ...')
# We use a specific version of RDKit to ensure consistent results.
!conda install -q -y -c rdkit rdkit=2020_03_6
# cairosvg is imported in the next cell and used to render SVG to PNG bytes.
!pip -q install cairosvg==2.5.2
print('DONE.')

## Imports

In [None]:
from io import BytesIO

import numpy as np
from PIL import Image

from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import rdMolDraw2D

import lxml.etree as et
import cairosvg

from matplotlib import pyplot as plt

## Implementation

Below, the atoms and bonds in the molecule are shown with information about each atom:

* the atom's element symbol
* the index of the atom within the RDKit molecule object
* the number of hydrogens connected to the atom
* the coordinates of the atom in the rendered output image
* the type of chirality for the atom (if applicable)
* whether the atom is part of a ring (shown as '⭕')

and its bonds:

* the atoms at each end of the bond (specified by index)
* bond strength
* the bond type (this seems only relevant to stereochemistry)
* whether the bond is part of a ring (shown as '⭕')

In [None]:
def svg_to_image(svg, convert_to_greyscale=True):
    """Rasterise a parsed SVG element into a ``float32`` numpy array.

    The element is serialised with ``lxml.etree.tostring``, rendered to PNG
    bytes by ``cairosvg.svg2png`` and decoded through PIL.

    Args:
        svg: SVG element (anything ``lxml.etree.tostring`` accepts).
        convert_to_greyscale: when true, collapse the last axis of the decoded
            image by taking its mean.

    Returns:
        The decoded image as ``np.float32``; the last axis is averaged away
        when ``convert_to_greyscale`` is true.
    """
    # TODO: would prefer to convert SVG directly to a numpy array.
    png_bytes = cairosvg.svg2png(bytestring=et.tostring(svg))
    pixels = np.array(Image.open(BytesIO(png_bytes)), dtype=np.float32)
    if not convert_to_greyscale:
        return pixels
    # Naive greyscale conversion: unweighted mean over every decoded channel.
    # NOTE(review): if the PNG decodes with an alpha channel it is averaged in
    # as well - confirm that is intended.
    return pixels.mean(axis=-1)


def _smear_atom_layer(layer, steps):
    """Blur a 2-D marker layer by repeated shift-and-add, renormalising to a peak of 1.

    Each step adds the layer to copies of itself shifted by one pixel along
    both directions of both axes, so marked pixels grow by one pixel per step.
    The layer is zero-padded by ``steps`` first so that ``np.roll`` only ever
    wraps zeros: a marker close to one border can no longer bleed onto the
    opposite border.

    Args:
        layer: 2-D float array with non-zero values at the marked pixels.
        steps: number of smear iterations (also the padding width).

    Returns:
        Array of the same shape as ``layer``. An all-zero layer (or
        ``steps <= 0``) is returned unchanged, avoiding a 0/0 division.
    """
    if steps <= 0 or not layer.any():
        return layer
    rows, cols = layer.shape
    padded = np.pad(layer, steps, mode='constant')
    for _ in range(steps):
        padded = padded + np.roll(padded, 1, axis=0)
        padded = padded + np.roll(padded, -1, axis=0)
        padded = padded + np.roll(padded, -1, axis=1)
        padded = padded + np.roll(padded, 1, axis=1)
        padded = padded / padded.max()
    # Crop the padding back off (explicit end indices, so steps == 0 would also be safe).
    return padded[steps:steps + rows, steps:steps + cols]


def _print_atom_report(iatom, atom, x, y):
    """Print one summary line for ``atom`` followed by one line per bond it takes part in."""
    num_h = atom.GetNumExplicitHs() + atom.GetNumImplicitHs()
    chiral_tag = atom.GetChiralTag()
    chiral_label = f' {chiral_tag}' if chiral_tag != Chem.rdchem.ChiralType.CHI_UNSPECIFIED else ''
    ring_label = ' ⭕' if atom.IsInRing() else ''
    print(f'{atom.GetSymbol()} [#{iatom}]: hydrogens={num_h}, coords=({x:0.1f}, {y:0.1f}){chiral_label}{ring_label}')
    # Bonds are listed per atom, so every bond shows up once under each of its two atoms.
    for bond in atom.GetBonds():
        stereo = bond.GetStereo()
        stereo_label = f' {stereo}' if stereo != Chem.rdchem.BondStereo.STEREONONE else ''
        bond_direction = bond.GetBondDir()
        bond_direction_label = f' {bond_direction} →' if bond_direction != Chem.rdchem.BondDir.NONE else ' →'
        in_ring_label = ' ⭕' if bond.IsInRing() else ''
        bond_type = bond.GetBondType()
        bond_strength = bond.GetBondTypeAsDouble()
        from_atom = bond.GetBeginAtomIdx()
        to_atom = bond.GetEndAtomIdx()
        print(f'\t[#{from_atom}] → {bond_type} ({bond_strength}){bond_direction_label} [#{to_atom}]{stereo_label}{in_ring_label}')


def drawmol(mol, size, add_hs=False, smear_steps=20):
    """Render a molecule, print per-atom/per-bond details and plot atom positions.

    The molecule is drawn to a ``2*size`` x ``size`` SVG, rasterised with
    ``svg_to_image`` and inverted. Each atom's draw coordinates are marked in
    a separate layer, blurred, and shown in the red and blue channels of the
    final plot, with the drawing itself (scaled by 0.8) in the green channel.

    Side effects: prints one line per atom and per bond, calls ``display`` on
    the prepared molecule and shows a matplotlib figure. ``Compute2DCoords``
    is called on the molecule passed in (or on the ``AddHs`` result when
    ``add_hs`` is true).

    Args:
        mol: RDKit molecule to draw.
        size: height of the drawing in pixels; the width is ``2*size``.
        add_hs: if true, add hydrogens with ``Chem.AddHs`` before drawing.
        smear_steps: number of blur iterations applied to the atom markers.

    Returns:
        The string ``'Done.'``.
    """
    if add_hs:
        mol = Chem.AddHs(mol)
    drawer = rdMolDraw2D.MolDraw2DSVG(2 * size, size)
    AllChem.Compute2DCoords(mol)
    mol = rdMolDraw2D.PrepareMolForDrawing(mol)
    drawer.DrawMolecule(mol)
    drawer.FinishDrawing()
    svg = et.fromstring(drawer.GetDrawingText().encode('iso-8859-1'))
    # Invert so the background becomes 0 and drawn strokes become positive.
    img = 1 - svg_to_image(svg) / 255
    height, width = img.shape

    xs, ys = [], []
    for iatom, atom in enumerate(mol.GetAtoms()):
        point = drawer.GetDrawCoords(iatom)
        xs.append(point.x)
        ys.append(point.y)
        _print_atom_report(iatom, atom, point.x, point.y)
    display(mol)

    # Row index comes from y, column index from x; clamp to the image bounds.
    coords = np.array([ys, xs]).round().astype(np.int32)
    coords[0] = np.clip(coords[0], 0, height - 1)
    coords[1] = np.clip(coords[1], 0, width - 1)
    atom_layer = np.zeros_like(img)
    atom_layer[coords[0], coords[1]] = 1
    atom_layer = _smear_atom_layer(atom_layer, smear_steps)

    img_rgb = np.stack([atom_layer, 0.8 * img, atom_layer], axis=-1)
    plt.figure(figsize=(14, 14))
    # No cmap: matplotlib ignores it for RGB input.
    plt.imshow(img_rgb, interpolation='bilinear')
    plt.show()
    return 'Done.'


# Example molecule, given as an InChI string. The literal is split at its
# "/"-separated layers purely for readability; adjacent string literals are
# concatenated into the same single string.
INCHI = (
    'InChI=1S/C19H18ClN3O5S'
    '/c1-8-11(12(22-28-8)9-6-4-5-7-10(9)20)15(24)21-13-16(25)23-14(18(26)27)19(2,3)29-17(13)23'
    '/h4-7,13-14,17H,1-3H3,(H,21,24)(H,26,27)'
    '/t13-,14+,17-/m1/s1'
)
# Parse the InChI into an RDKit molecule and run the demo on it.
mol = Chem.inchi.MolFromInchi(INCHI)
drawmol(mol, size=200)
