<h1>Generate InChI sequences with LSTMs</h1>

<p>Generate InChI sequences with an RNN trained on the ChEMBL chemical representations (SMILES/InChI) dataset, downloaded from <a href='https://ftp.ebi.ac.uk/pub/databases/chembl/ChEMBLdb/latest/'>here</a> (file <code>chembl_28_chemreps.txt.gz</code>).</p>
<p>Based on <a href='https://blog.bayeslabs.co/2019/07/04/Generating-Molecules-using-Char-RNN-in-Pytorch.html'>https://blog.bayeslabs.co/2019/07/04/Generating-Molecules-using-Char-RNN-in-Pytorch.html</a>. The links in the article to the complete code are broken, so I've also resorted to this <a href='https://github.com/kevaday/pytorch-char-rnn'>code</a> and this <a href='https://www.youtube.com/watch?v=bbvr-2hY4mE'>video</a>.</p>

Playing a bit with the hyperparameters (batch size of 512, 4 layers of LSTM, hidden size of 512, embedding dim of 16, ...), I've been able to achieve a loss of around 0.15 and an average score (Levenshtein distance between the target and predicted InChI sequences) of around 25.

In [None]:
# Imports

import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import LogSoftmax, NLLLoss, Softmax, CrossEntropyLoss
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
import random
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
import Levenshtein

In [None]:
# Get and uncompress the data

!rm chembl_28_*
!wget --no-check-certificate https://ftp.ebi.ac.uk/pub/databases/chembl/ChEMBLdb/latest/chembl_28_chemreps.txt.gz
!gunzip chembl_28_chemreps.txt.gz
!ls

In [None]:
# Notebook-wide settings

INPUT_DIR = '.'                  # directory that holds the uncompressed data file
DATA_COLUMN = 'standard_inchi'   # DataFrame column with the InChI strings

# Run on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [None]:
%%time
# Load the tab-separated data file into a DataFrame and preview its first rows
chemreps_file = os.path.join(INPUT_DIR, 'chembl_28_chemreps.txt')
chemreps = pd.read_csv(chemreps_file, sep='\t')
chemreps.head()

In [None]:
# Discard every row that has a missing value in any column (this removes the rows
# with a null standard_inchi), then show the remaining null count per column
chemreps = chemreps.dropna(axis=0, how='any')
chemreps.isna().sum()

For this task, we're only interested in the column <code>standard_inchi</code>. Let's create dictionaries for all the characters in this column's values, to map each character to an integer value and back.

In [None]:
%%time
# Vocabulary: every distinct character that appears in the InChI column, sorted
characters = sorted(set(''.join(chemreps[DATA_COLUMN])))

# Tokens 0, 1 and 2 are reserved for the pad, start of sequence and end of sequence
# markers; the real characters are numbered from 3 onwards
special_tokens = {0: '', 1: '[', 2: ']'}
int2char = {token: ch for token, ch in enumerate(characters, start=3)}
int2char.update(special_tokens)
num_characters = len(characters) + 3

# Reverse map. The special tokens were inserted last, so they are also visited last
# here and '[' / ']' always end up mapped to 1 / 2
char2int = {ch: token for token, ch in int2char.items()}
PAD = 0
SOS, EOS = char2int['['], char2int[']']  # start / end of sequence values
print(f'there are {num_characters} characters:\n {char2int}')

In [None]:
# Summary of the InChI string lengths: the two extremes plus some percentiles

lengths = chemreps[DATA_COLUMN].apply(len)
percentiles = [50, 75, 90, 95, 96, 97, 98, 99]
shortest, longest = lengths.min(), lengths.max()
print(f'minumum length={shortest} maximum length={longest}')
print(np.percentile(lengths, percentiles))

In [None]:
# Fixed length of every token sequence handed to the model
SEQ_LENGTH = 403

# Two positions of each sequence are taken by the SOS and EOS tokens, so only the
# InChI strings of at most SEQ_LENGTH - 2 characters are kept
fits_in_sequence = lengths <= SEQ_LENGTH - 2
data = chemreps.loc[fits_in_sequence, DATA_COLUMN]

<h2>Dataset</h2>

In [None]:
class InChIDataset(Dataset):
    '''
    Dataset that generates an array of fixed length of tokens (int) from InChI strings,
    prefixing and postfixing them with the SOS (start of sequence) and EOS (end of sequence)
    special characters and padding with the special character PAD till a fixed length if
    the sequence is smaller.

    Each element is a pair (x, y) where y is x shifted one position to the left, i.e.
    y[t] is the token that follows x[t] (the last position of y is PAD).
    '''

    def __init__(self, data, seq_length, ctoi):
        '''
        Initializes the dataset
        Args:
            data: a sized collection of InChI strings (len() is called on it)
            seq_length: the seq_length of the resulting array of tokens, at least 2
                (room for SOS and EOS). Longer strings are truncated to
                seq_length - 2 characters
            ctoi: a map to convert InChI characters to tokens
        Raises:
            ValueError: if seq_length is smaller than 2
            KeyError: if a string contains a character that is not in ctoi
        '''

        # Zero-argument form: `super(Dataset)` would only build an unbound super
        # object and never run the initializer of the base class
        super().__init__()

        if seq_length < 2:
            raise ValueError(f'seq_length must be at least 2 (SOS + EOS), got {seq_length}')

        # int8 keeps the two arrays small; NOTE(review): this assumes every token
        # value fits in an int8 (vocabulary below 128 entries) - confirm
        x = np.full((len(data), seq_length), PAD, dtype=np.int8)
        y = np.full((len(data), seq_length), PAD, dtype=np.int8)

        x[:, 0] = SOS

        for i, inchi in enumerate(data):
            # Number of characters that fit between SOS and EOS
            last_ch = min(seq_length - 2, len(inchi))
            x[i, 1:last_ch + 1] = [ctoi[ch] for ch in inchi[:last_ch]]
            x[i, last_ch + 1] = EOS

        self.x = x
        y[:, :-1] = x[:, 1:]  # shift 1 place to the left
        self.y = y

    def __len__(self):
        '''
        Returns:
            The number of elements in the dataset
        '''
        return self.x.shape[0]

    def __getitem__(self, idx):
        '''
        Args:
            idx: The index of the dataset element to return
        Returns:
            The x and y tensors (dtype long) for element idx of the dataset
        '''
        return torch.from_numpy(self.x[idx]).long(), torch.from_numpy(self.y[idx]).long()
    


<h2>Model</h2>

In [None]:
class InChIRNN(nn.Module):
    '''
    Character level model: an embedding layer, followed by a (possibly stacked) LSTM
    and a linear layer that outputs one logit per token for every input position.
    '''

    def __init__(self, num_characters, embedding_dim, hidden_size, output_size, num_layers=1, dropout=0):
        '''
        Args:
            num_characters: number of entries of the embedding table
            embedding_dim: size of the embedding vectors
            hidden_size: size of the LSTM hidden state
            output_size: number of logits produced for each position
            num_layers: number of stacked LSTM layers
            dropout: dropout probability used between the stacked LSTM layers
        '''

        super().__init__()

        self.num_layers = num_layers
        self.hidden_size = hidden_size

        self.embed = nn.Embedding(num_characters, embedding_dim)
        # nn.LSTM applies dropout only between stacked layers; with a single layer a
        # non-zero value has no effect and just triggers a warning, so it is zeroed
        lstm_dropout = dropout if num_layers > 1 else 0
        self.rnn = nn.LSTM(embedding_dim, hidden_size, num_layers, batch_first=True, dropout=lstm_dropout)
        # Kept as an attribute of the module; forward() does not apply it
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        '''
        Args:
            x: tensor of tokens; the embedding output goes to a batch_first LSTM, so
               the expected shape is (batch_size, seq_length)
            hidden: tuple (hidden state, cell state) as returned by init_hidden
        Returns:
            The logits, flattened to (batch_size * seq_length, output_size), and the
            updated (hidden state, cell state) tuple
        '''

        x_e = self.embed(x)
        out, hidden = self.rnn(x_e, hidden)

        # Stack up the LSTM outputs: one row per (sequence, position) pair
        out = out.reshape(-1, self.hidden_size)
        out = self.fc(out)

        return out, hidden

    def init_hidden(self, batch_size):
        '''
        Args:
            batch_size: number of sequences in the batch
        Returns:
            A tuple (hidden state, cell state) of zero tensors of shape
            (num_layers, batch_size, hidden_size), created on the same device and
            with the same dtype as the parameters of the model
        '''

        # Follow the model itself instead of a global device name, so the state can
        # never end up on a different device than the weights
        reference = next(self.parameters())
        shape = (self.num_layers, batch_size, self.hidden_size)
        return reference.new_zeros(shape), reference.new_zeros(shape)
    
    

<h2>Some utility functions</h2>

In [None]:
def seq2array(seq):
    '''
    Encode an InChI string as an int8 array of tokens of shape (1, len(seq) + 2):
    the SOS token, one token per character of the string and the EOS token
    '''
    tokens = [SOS]
    tokens.extend(char2int[ch] for ch in seq)
    tokens.append(EOS)
    return np.array([tokens], dtype=np.int8)

def array2seq(array):
    '''
    Decode an array of integer tokens of shape (1, n) back to an InChI string. The
    first and last positions are expected to hold the SOS and EOS special tokens
    and are left out of the result
    '''
    tokens = array.squeeze(0)
    decoded = list(map(int2char.__getitem__, tokens))
    # Drop the two outermost entries (SOS and EOS)
    return ''.join(decoded[1:-1])

def score(y, y_hat):
    '''
    Average Levenshtein distance between the decoded rows of y and y_hat
    Args:
        y: True labels (dimension batch_size * seq_length) of tokens (int)
        y_hat: Predicted labels (dimension batch_size * seq_length) of tokens (int)
    Returns:
        The sum of the Levenshtein distances between each row of y and the same row
        of y_hat (both decoded to strings with int2char), divided by batch_size
    '''

    def decode(tokens):
        # The PAD token maps to the empty string, so it vanishes from the result
        return ''.join([int2char[token] for token in tokens])

    batch_size = y.shape[0]

    targets = [decode(y[row]) for row in range(batch_size)]
    predictions = [decode(y_hat[row]) for row in range(batch_size)]

    total_distance = sum(
        Levenshtein.distance(target, prediction)
        for target, prediction in zip(targets, predictions)
    )

    return total_distance / batch_size

<h2>Training</h2>

In [None]:
# Hyperparameters

# Optimisation
EPOCHS = 2
BATCH_SIZE = 512
LR = 1e-3        # initial learning rate
CLIPPING = 2     # maximum norm used when clipping the gradients

# Model
EMBEDDING_DIM = 16
HIDDEN_SIZE = 256
NUM_LAYERS = 4
DROPOUT = 0.1

# Logging: the loss is printed every PRINT_EVERY training steps
PRINT_EVERY = 100

In [None]:
# Hold out 20% of the sequences for validation. No random_state is passed, so the
# split differs from run to run

train_val_parts = train_test_split(data, test_size=0.20)
train, val = train_val_parts
print(f'train: {train.shape}  validation: {val.shape}')

In [None]:
# Wrap both splits in token datasets and batch them. Only the training batches are
# shuffled

train_dataset, val_dataset = (
    InChIDataset(part.values, SEQ_LENGTH, char2int) for part in (train, val)
)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE)

In [None]:
# Build the network, its optimizer and the loss function. The model outputs one
# logit per token of the vocabulary, hence output_size = num_characters

model = InChIRNN(
    num_characters=num_characters,
    embedding_dim=EMBEDDING_DIM,
    hidden_size=HIDDEN_SIZE,
    output_size=num_characters,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
)
optimizer = Adam(model.parameters(), lr=LR)
criterion = CrossEntropyLoss()

In [None]:
# Train and validation loop

lr = LR

model.to(DEVICE)

# Real number of batches per pass. len() of a DataLoader rounds up, so the last,
# smaller batch is counted too (flooring with `//` would leave it out and bias the
# averages below, or divide by zero for a dataset smaller than one batch)
num_train_batches = len(train_dataloader)
num_val_batches = len(val_dataloader)

step = 0
for epoch in range(EPOCHS):
    print(f'\n----------------------- epoch {epoch+1}/{EPOCHS} ------------------------')
    model.train()
    running_loss = 0
    for x, y in tqdm(train_dataloader,
                     desc='Training loop: ',
                     total=num_train_batches):

        x, y = x.to(DEVICE), y.to(DEVICE)

        # A fresh zero state for every batch: the sequences are independent
        h0, c0 = model.init_hidden(x.shape[0])
        output, _ = model(x, (h0, c0))

        # Calculate cross-entropy loss. The average over all the positions of the batch
        # is returned (reduction parameter, with default value 'mean'). The criterion was
        # created without ignore_index, so the PAD positions count as well
        loss = criterion(output.view(-1, num_characters), y.view(-1))
        current_loss = loss.item()
        running_loss += current_loss

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIPPING)
        optimizer.step()

        if step % PRINT_EVERY == 0:
            print(f'step:{step} loss:{current_loss}')

        step += 1

    # End of epoch

    # Print training loss for the epoch (mean of the batch losses)
    print(f'*** training loss = {running_loss/num_train_batches}')

    # Validation loop

    model.eval()
    validation_loss = 0
    sum_levenshtein = 0     # Sum of the mean levenshtein distance of each batch

    with torch.no_grad():

        for x, y in tqdm(val_dataloader, desc='Validation loop: ',
                         total=num_val_batches):

            x, y = x.to(DEVICE), y.to(DEVICE)

            h0, c0 = model.init_hidden(x.shape[0])
            output, _ = model(x, (h0, c0))

            # Calculate cross-entropy loss. The average over all the positions of the batch is returned
            loss = criterion(output.view(-1, num_characters), y.view(-1))
            current_loss = loss.item()
            validation_loss += current_loss

            # Calculate score: the most likely token at every position against the target
            y_hat = torch.argmax(output.view(-1, SEQ_LENGTH, num_characters), dim=-1)
            sum_levenshtein += score(y.cpu().numpy(), y_hat.cpu().numpy())

    print(f'*** validation loss = {validation_loss/num_val_batches} ')
    print(f'*** validation score = {sum_levenshtein/num_val_batches}')

    # Change learning rate every N epochs. With N = 1 the condition is always true,
    # so the learning rate is divided by 5 at the end of every epoch

    if epoch % 1 == 0:
        lr = lr / 5
        print(f'\nLearning rate changed lr={lr}')
        for g in optimizer.param_groups:
            g['lr'] = lr

In [None]:
# Write the parameters of the model, as they are after the last epoch, to disk

final_state = model.state_dict()
torch.save(final_state, 'inchi.pth')

<h2>Generate some InChI strings</h2>

In [None]:
def generate(model, T=1):
    '''
    Sample an InChI string from the model, one character at a time, starting from the
    SOS token and stopping when the EOS token is sampled or after SEQ_LENGTH steps
    Args:
        model: the trained model. It is switched to evaluation mode
        T: temperature the logits are divided by before the softmax
           (NOTE(review): presumably expected to be > 0 - confirm)
    Returns:
        The generated string, without the SOS and EOS special characters
    '''
    model.eval()
    sequence = []
    hidden = model.init_hidden(1)
    char_idx = SOS

    # Sampling needs no gradients: without no_grad() every step would record an
    # autograd graph that keeps growing through the hidden state
    with torch.no_grad():
        for _ in range(SEQ_LENGTH):
            # Batch with a single sequence of a single token
            x = torch.tensor([[char_idx]], dtype=torch.long, device=DEVICE)
            output, hidden = model(x, hidden)

            probs = F.softmax(output / T, dim=1).squeeze()
            char_idx = torch.multinomial(probs, 1).item()
            if char_idx == EOS:
                break
            sequence.append(int2char[char_idx])

    return ''.join(sequence)

In [None]:
# Sample ten sequences from the model and print each one followed by a blank line
for _ in range(10):
    print(generate(model), end='\n\n')