# Molecular graph generation with PyTorch and PyGeometric
> We use [GraphVAE](https://arxiv.org/abs/1802.03480) for molecule generation: the model generates a probabilistic graph of a predefined maximum size in one shot.

- toc: true 
- badges: true
- comments: false
- author: Anirudh Jain
- categories: [graph generation, pytorch, pygeometric, tutorial]

# Requirements

The following packages need to be installed:
- rdkit
- pytorch
- torch_geometric
- networkx

```python
#collapse-hide

# Initial imports
import numpy as np
import torch
import matplotlib.pyplot as plt
from glob import glob
import tqdm
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
```

# Introduction

We represent a molecule as a graph $G = (\mathcal{X}, \mathcal{A})$ using the PyGeometric framework. Each molecule is represented by a feature matrix $\mathcal{X}$ and an adjacency matrix $\mathcal{A}$. We use the QM9 dataset from [MoleculeNet: A Benchmark for Molecular Machine Learning](https://arxiv.org/abs/1703.00564), as implemented in `torch_geometric.datasets.QM9`. PyGeometric relies on `rdkit` to parse the molecules and convert them into graphs.
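
PyGeometric stores the connectivity of $G$ not as a dense matrix but as an `edge_index` array in COO format, with every undirected bond listed twice, once per direction. The minimal NumPy sketch below shows how a dense, symmetric bond-label adjacency matrix $\mathcal{A}$ can be recovered from such an edge list. The helper `dense_adjacency`, the toy molecule and its integer bond labels are hypothetical and only meant for illustration.

```python
import numpy as np

def dense_adjacency(edge_index, bond_labels, num_nodes):
    """Build a symmetric N x N matrix of bond labels from a COO edge list.

    edge_index:  2 x E array of (source, target) pairs, each bond listed in both directions
    bond_labels: length-E array of integer bond labels (assumed 1=SINGLE, 2=DOUBLE, 3=TRIPLE)
    """
    A = np.zeros((num_nodes, num_nodes), dtype=int)  # 0 = no bond
    A[edge_index[0], edge_index[1]] = bond_labels
    return A

# Hypothetical toy molecule: heavy atoms C0-C1=O2 (acetaldehyde without hydrogens)
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
bond_labels = np.array([1, 1, 2, 2])
A = dense_adjacency(edge_index, bond_labels, num_nodes=3)
print(A)
# [[0 1 0]
#  [1 0 2]
#  [0 2 0]]
```

Because each bond appears in both directions, the resulting matrix is symmetric, which is why only its upper triangle needs to be predicted later on.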

We modify the data processing script in two ways:
- We strip hydrogen atoms from the molecules to keep only the heavy atoms.
- We kekulize the molecules to convert aromatic rings to their Kekulé form, i.e. explicit alternating single and double bonds.

The modified script can be found [here](https://gist.github.com/sponde25/7dfa5492c21c007cf1e60a02dced1334).

After processing, the molecules contain 4 types of heavy atoms (C, N, O, F) and 3 bond types (SINGLE, DOUBLE, TRIPLE), with a maximum graph size of 9 atoms.

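The decoder outputs the graph as one-hot encoded matrices for atoms, of size `[9 x 5]`, and for bonds, of size `[36 x 4]`: a graph with 9 atoms has 9 x 8 / 2 = 36 possible bonds, one for each entry of the strict upper triangle of the adjacency matrix. In both cases the label 0 marks an empty atom slot or a missing bond.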
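
To make these shapes concrete, here is a minimal NumPy sketch of turning padded integer labels into the one-hot targets. It assumes the label key introduced further down (C=1, N=2, O=3, F=4; SINGLE=1, DOUBLE=2, TRIPLE=3; 0 = empty) and the row-by-row upper-triangle ordering used by `index_array` in this post. The `one_hot` helper and the toy molecule are hypothetical.

```python
import numpy as np

MAX_N, MAX_ATOM, MAX_EDGE = 9, 5, 4
MAX_E = MAX_N * (MAX_N - 1) // 2  # 36 entries in the strict upper triangle

def one_hot(labels, num_classes):
    """Turn an integer label vector into a [len(labels) x num_classes] one-hot matrix."""
    return np.eye(num_classes, dtype=np.float32)[labels]

# Hypothetical padded labels for C-C=O: atoms C, C, O, then 6 empty slots (label 0)
atom_labels = np.array([1, 1, 3] + [0] * (MAX_N - 3))
edge_labels = np.zeros(MAX_E, dtype=int)
edge_labels[0] = 1  # upper-triangle slot 0 is the pair (0, 1): SINGLE
edge_labels[8] = 2  # upper-triangle slot 8 is the pair (1, 2): DOUBLE

X = one_hot(atom_labels, MAX_ATOM)  # shape (9, 5)
A = one_hot(edge_labels, MAX_EDGE)  # shape (36, 4)
print(X.shape, A.shape)  # (9, 5) (36, 4)
```

Every row of `X` and `A` contains exactly one 1, so empty slots are predicted explicitly as class 0 rather than being left out.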

```python
# Imports for data pre-processing
import torch_geometric
from qm9_modified import QM9  # modified QM9 processing script from the gist linked above
from torch_geometric.utils.convert import to_networkx
import networkx
```

```python
# Setting up variables for the dataset
MAX_ATOM = 5  # 4 atom types (C, N, O, F) + 1 empty label
MAX_EDGE = 4  # 3 bond types (SINGLE, DOUBLE, TRIPLE) + 1 "no bond" label
path = '/scratch/project_2002655/datasets/qm9_noH'  # Change the path for your local directory structure
dataset = QM9(path)

# Maximum number of (heavy) atoms in any graph of the dataset
MAX_N = max(data.num_nodes for data in dataset)
# Number of entries in the strict upper triangle of a MAX_N x MAX_N adjacency matrix
MAX_E = MAX_N * (MAX_N - 1) // 2
print('MAX ATOMS: {}'.format(MAX_N))
print('MAX EDGE: {}'.format(MAX_E))
```

```
MAX ATOMS: 9
MAX EDGE: 36
```


`torch_geometric` stores each graph as a `torch_geometric.data.Data` object, from which we generate the one-hot representation of the graph $G$ described above. For each graph $G$, we create a one-hot atom matrix $\mathcal{X}$ of dimension `[MAX_N x MAX_ATOM]` and a one-hot bond matrix $\mathcal{A}$ of dimension `[MAX_E x MAX_EDGE]`.

![](../images/data_representation.png "A visualization of the graph, atom and edge representations")

```python
# Map each upper-triangular position (i, j), i < j, to its index in the flattened edge vector A.
# Entries on and below the diagonal are unused and stay 0.
index_array = torch.zeros([MAX_N, MAX_N], dtype=torch.long)
idx = 0
for i in range(MAX_N):
    for j in range(i + 1, MAX_N):
        index_array[i, j] = idx
        idx += 1

print(index_array)
```

```
tensor([[ 0,  0,  1,  2,  3,  4,  5,  6,  7],
        [ 0,  0,  8,  9, 10, 11, 12, 13, 14],
        [ 0,  0,  0, 15, 16, 17, 18, 19, 20],
        [ 0,  0,  0,  0, 21, 22, 23, 24, 25],
        [ 0,  0,  0,  0,  0, 26, 27, 28, 29],
        [ 0,  0,  0,  0,  0,  0, 30, 31, 32],
        [ 0,  0,  0,  0,  0,  0,  0, 33, 34],
        [ 0,  0,  0,  0,  0,  0,  0,  0, 35],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0]])
```
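
The lookup table enumerates the strict upper triangle row by row, so the mapping also has a closed form: rows $0, \dots, i-1$ contribute $i N - i(i+1)/2$ entries, giving $k = iN - i(i+1)/2 + (j - i - 1)$ for $i < j$. NumPy's `np.triu_indices` with `k=1` yields the pairs in the same row-major order, which gives a direct inverse instead of a linear search. A minimal sketch; the helper names `pair_to_index` and `index_to_pair` are hypothetical.

```python
import numpy as np

N = 9  # MAX_N in this post

def pair_to_index(i, j, n=N):
    """Index of the bond (i, j), i < j, in the row-major flattened upper triangle."""
    assert 0 <= i < j < n
    # Rows 0..i-1 contribute (n-1) + (n-2) + ... + (n-i) = i*n - i*(i+1)/2 entries
    return i * n - i * (i + 1) // 2 + (j - i - 1)

# Inverse mapping: position k of these arrays holds the atom pair (rows[k], cols[k])
rows, cols = np.triu_indices(N, k=1)

def index_to_pair(k):
    return int(rows[k]), int(cols[k])

print(pair_to_index(1, 2), pair_to_index(7, 8))  # 8 35
print(index_to_pair(15))                          # (2, 3)
```

Both functions agree with the `index_array` printed above, so either can replace the nested loops when decoding edges.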


We process the `torch_geometric` dataset to generate the matrices $\mathcal{X}$ and $\mathcal{A}$, which act as the ground truth for our decoder. We also set up utility functions to convert between the vector representation $(\mathcal{X}, \mathcal{A})$ and the `torch_geometric.data` representation of $G$. We use the following key for atoms and bonds:
```
    C: 1    SINGLE: 1
    N: 2    DOUBLE: 2
    O: 3    TRIPLE: 3
    F: 4
```
`0` is the placeholder label for empty entry.
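
To make the key concrete, here is a minimal NumPy sketch of the label construction that the next cell performs on the real dataset, run on a hypothetical three-atom molecule C-C=O. The function `to_label_vectors` is hypothetical, and its inputs are assumed to already follow the key above, with bonds listed in both directions as in PyGeometric.

```python
import numpy as np

MAX_N = 9
MAX_E = MAX_N * (MAX_N - 1) // 2  # 36 upper-triangle positions

# Same row-major upper-triangle numbering as index_array
index_array = np.zeros((MAX_N, MAX_N), dtype=int)
rows, cols = np.triu_indices(MAX_N, k=1)
index_array[rows, cols] = np.arange(MAX_E)

def to_label_vectors(atom_label, edge_index, edge_label):
    """Pad atom labels to MAX_N and scatter bond labels into the MAX_E upper-triangle slots."""
    atoms = np.zeros(MAX_N, dtype=int)       # 0 = empty atom slot
    atoms[:len(atom_label)] = atom_label
    edges = np.zeros(MAX_E, dtype=int)       # 0 = no bond
    upper = edge_index[0] < edge_index[1]    # keep each undirected bond once
    edges[index_array[edge_index[0, upper], edge_index[1, upper]]] = edge_label[upper]
    return atoms, edges

# C-C=O: atoms C(1), C(1), O(3); each bond listed in both directions
atom_label = np.array([1, 1, 3])
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
edge_label = np.array([1, 1, 2, 2])  # SINGLE, SINGLE, DOUBLE, DOUBLE
atoms, edges = to_label_vectors(atom_label, edge_index, edge_label)
print(atoms)                                          # [1 1 3 0 0 0 0 0 0]
print(np.flatnonzero(edges), edges[np.flatnonzero(edges)])  # [0 8] [1 2]
```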

```python
# Initialize the labels with -1 (shifted to 0 = "empty" at the end)
edge_labels = torch.full((len(dataset), MAX_E), -1, dtype=torch.long)
atom_labels = torch.full((len(dataset), MAX_N), -1, dtype=torch.long)
for idx, data in enumerate(dataset):
    edge_attr = data.edge_attr                   # One-hot encoded bond types
    edge_index = data.edge_index                 # COO edge list, each bond stored in both directions
    upper_index = edge_index[0] < edge_index[1]  # Mask keeping each bond once, as (i, j) with i < j
    edge_label = edge_attr.argmax(dim=-1)        # Bond labels from one-hot vectors
    x = data.x[:, 1:5]                           # One-hot atom types C, N, O, F (column 0 is H)
    atom_label = x.argmax(dim=-1)                # Atom labels from one-hot vectors
    # Pad the label vectors to length MAX_N and MAX_E
    atom_labels[idx, :len(atom_label)] = atom_label
    a0 = edge_index[0, upper_index]
    a1 = edge_index[1, upper_index]
    up_idx = index_array[a0, a1]                 # Positions in the flattened upper triangle
    edge_labels[idx, up_idx] = edge_label[upper_index]

# Shift so that 0 = empty and the remaining labels follow the key above
atom_labels += 1
edge_labels += 1
```

Now that we have the dataset represented as $(\mathcal{X}, \mathcal{A})$, let's plot some graphs to visually check that the molecules look as expected. We use `rdkit` to draw the molecules, which does a lot of the heavy lifting for us. The function `graphToMol` takes the label vectors for $\mathcal{X}$ and $\mathcal{A}$ and returns an `rdkit` molecule (a `Chem.RWMol`), or `None` if a bond is attached to an empty atom slot. We could also visualize the graphs $G$ with `torch_geometric.utils.convert.to_networkx` and plot the resulting networkx graph, but `rdkit` draws molecules in a canonical orientation and is built to minimize intramolecular clashes, i.e. to maximize the clarity of the drawing.
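
The decoding logic of `graphToMol` does not depend on `rdkit` itself: drop the empty atom slots, map each non-zero edge slot back to its atom pair, and reject graphs in which a bond touches an empty slot. A minimal sketch under those assumptions, which returns plain atom symbols and bond tuples instead of an `rdkit` molecule; `decode_graph` is a hypothetical helper, and it re-indexes the remaining atoms so that empty slots in the middle of a generated graph are also handled.

```python
import numpy as np

ATOM_SYMBOLS = {1: 'C', 2: 'N', 3: 'O', 4: 'F'}
BOND_NAMES = {1: 'SINGLE', 2: 'DOUBLE', 3: 'TRIPLE'}

def decode_graph(atom_labels, edge_labels, n=9):
    """Decode padded label vectors into (atom symbols, bond list).

    Returns None if a bond touches an empty atom slot, mirroring graphToMol.
    """
    rows, cols = np.triu_indices(n, k=1)  # edge slot k <-> atom pair (rows[k], cols[k])
    kept = [i for i, a in enumerate(atom_labels) if a != 0]
    new_idx = {old: new for new, old in enumerate(kept)}  # compact indices without empty slots
    symbols = [ATOM_SYMBOLS[int(atom_labels[i])] for i in kept]
    bonds = []
    for k, e in enumerate(edge_labels):
        if e == 0:
            continue
        i, j = int(rows[k]), int(cols[k])
        if i not in new_idx or j not in new_idx:
            return None  # invalid graph: bond to an empty atom slot
        bonds.append((new_idx[i], new_idx[j], BOND_NAMES[int(e)]))
    return symbols, bonds

atoms = [1, 1, 3, 0, 0, 0, 0, 0, 0]
edges = [0] * 36
edges[0], edges[8] = 1, 2  # (0, 1) SINGLE, (1, 2) DOUBLE
print(decode_graph(atoms, edges))
# (['C', 'C', 'O'], [(0, 1, 'SINGLE'), (1, 2, 'DOUBLE')])
```

A check like this is handy on decoder samples, where invalid graphs (bonds to empty slots) are expected and should be counted rather than drawn.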

```python
#collapse-hide
from rdkit import Chem

def get_index(index, index_array):
    """Return the atom pair [i, j], i < j, stored at position `index` of the flattened upper triangle."""
    n = index_array.shape[0]
    for i in range(n):
        for j in range(i + 1, n):
            if index_array[i, j] == index:
                return [i, j]

def graphToMol(atom, edge):
    """Convert padded atom/bond label vectors into an RDKit molecule (None if invalid)."""
    # Label 0 (empty slot) is temporarily added as a hydrogen placeholder and removed at the end
    possible_atoms = {0: 'H', 1: 'C', 2: 'N', 3: 'O', 4: 'F'}
    possible_edges = {
        1: Chem.rdchem.BondType.SINGLE,
        2: Chem.rdchem.BondType.DOUBLE,
        3: Chem.rdchem.BondType.TRIPLE,
    }

    mol = Chem.RWMol()
    for a in atom:
        mol.AddAtom(Chem.Atom(possible_atoms[int(a)]))
    # Indices of the placeholder (empty) atoms
    rem_idxs = [a.GetIdx() for a in mol.GetAtoms() if a.GetAtomicNum() == 1]
    for i, e in enumerate(edge):
        e = int(e)
        if e != 0:
            a0, a1 = get_index(i, index_array)
            if a0 in rem_idxs or a1 in rem_idxs:
                return None  # a bond touches an empty atom slot
            mol.AddBond(a0, a1, order=possible_edges[e])
    # Remove placeholders from the highest index down so the remaining indices stay valid
    for i in sorted(rem_idxs, reverse=True):
        mol.RemoveAtom(i)
    return mol
```

```python
# We pick 5 random molecules to plot
mols = []
for i in np.random.randint(0, len(atom_labels), size=5):
    print(i)
    mols.append(graphToMol(atom_labels[i], edge_labels[i]))
Draw.MolsToGridImage(mols, molsPerRow=5)
```

```
35336
92593
113203
101499
47505
```


# Model

# Results

# Conclusion