In [None]:
# If running on Google Colab, uncomment the following lines to install Miniconda and RDKit
# !wget -c https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
# !chmod +x Miniconda3-latest-Linux-x86_64.sh
# !time bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
# !time conda install -q -y -c conda-forge rdkit

In [None]:
# This notebook features a simple LSTM that can autoregressively complete SMILES strings.

# The purpose is to verify that a simple model can learn the short-range dependencies that
# valid SMILES require.

# Reference: https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html

In [None]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn

import random

import string
import time
import math

Tensor = torch.cuda.FloatTensor
from torch.autograd import Variable
import sys
import os
sys.path.append('/usr/local/lib/python3.7/site-packages/')
from rdkit import Chem

from collections import Counter

In [None]:
# Vocabulary of printable symbols: ASCII letters, then punctuation, then digits (94 in total).
all_characters = "".join((string.ascii_letters, string.punctuation, string.digits))
# Two extra indices follow the printable symbols: EOS, then PAD -> 96 entries overall.
n_characters = 2 + len(all_characters)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Data is from BindingDB: the full list of ligand SMILES strings.

data = pd.read_csv('smile_data.csv')
# Drop rows with a missing SMILES first. str(NaN) is 'nan', which contains no
# space, so the filter below would keep such rows and the float would later
# break the character-level encoding (len() of a float).
data = data.dropna(subset=['Ligand SMILES'])
# A space is not part of all_characters, so entries containing one cannot be encoded.
data = data[~data['Ligand SMILES'].astype(str).str.contains(' ', regex=False)]
len_data = len(data)
# Computed after filtering so the statistic describes the strings actually kept.
average_length = data['Ligand SMILES'].astype(str).str.len().mean()
print("Average length: {}".format(average_length))
print(data.head())

In [None]:
# Function to generate human-readable time
def time_since(since):
    """Format the wall-clock time elapsed since ``since`` as ``'<minutes>m <seconds>s'``.

    Args:
        since: A timestamp in seconds, as returned by ``time.time()``.

    Returns:
        A string with whole minutes and the remaining seconds truncated to an integer.
    """
    elapsed = time.time() - since
    minutes = math.floor(elapsed / 60)
    seconds = elapsed - minutes * 60
    return '%dm %ds' % (minutes, seconds)

# Batch generator for SMILES data.
class smile_generator():
    """Serves batches of SMILES strings from the module-level ``data`` frame and
    encodes them as padded arrays of character indices.

    Index layout (all module-level names): positions ``0 .. len(all_characters)-1``
    are the printable symbols, ``n_characters - 2`` is EOS and
    ``n_characters - 1`` is PAD.
    """

    def __init__(self):
        # Start of the next window handed out in sequential mode.
        self.current_index = 0

    def get_next_batch(self, batch_size, sequential):
        """Return a list of up to ``batch_size`` consecutive SMILES strings.

        Args:
            batch_size: Number of consecutive rows to return.
            sequential: If truthy, walk through the data window by window and
                wrap to the start when fewer than ``batch_size`` rows remain;
                otherwise pick a random window start.

        Returns:
            A list of SMILES strings (shorter than ``batch_size`` only when the
            dataset itself holds fewer rows).
        """
        smiles = data['Ligand SMILES']
        if sequential:
            # Wrap only when a full batch no longer fits; '<' (not '<=') so the
            # last complete window is still served.
            if len_data - self.current_index < batch_size:
                self.current_index = 0
            start = self.current_index
            self.current_index += batch_size
        else:
            # randint's upper bound is exclusive: +1 makes the final window
            # reachable and keeps the bound positive when len_data <= batch_size.
            start = np.random.randint(0, max(len_data - batch_size, 0) + 1)
        return list(smiles.iloc[start:start + batch_size])

    def target_smile(self, smile):
        """Encode ``smile`` as a LongTensor of character indices followed by EOS.

        Raises:
            ValueError: If a character is not part of ``all_characters``
                (``str.find`` would otherwise yield -1, an invalid token index).
        """
        character_indices = []
        for char in smile:
            index = all_characters.find(char)
            if index < 0:
                raise ValueError(
                    "Character {!r} in SMILES {!r} is not in the vocabulary".format(char, smile))
            character_indices.append(index)
        character_indices.append(n_characters - 2) # we added 2 earlier for EOS and PAD; n_characters - 2 is EOS
        return torch.LongTensor(character_indices)

    def get_n_samples(self, n=32, sequential=False):
        """Fetch a batch of SMILES and return it as a PAD-filled index matrix.

        Args:
            n: Requested batch size.
            sequential: Forwarded to ``get_next_batch``.

        Returns:
            ``(padded_Y, lengths)`` where ``padded_Y`` is a float NumPy array with
            one row per SMILES (indices, then EOS, then PAD up to the longest row)
            and ``lengths`` lists each SMILES length excluding its EOS.

        Raises:
            ValueError: If no SMILES strings are available.
        """
        # Pass `sequential` through; the original sequential branch omitted the
        # required argument and raised TypeError.
        batch = self.get_next_batch(n, sequential=sequential)
        if not batch:
            raise ValueError("No SMILES strings available to build a batch")

        # Iterate over what was actually returned rather than assuming n rows.
        lst_targets = [self.target_smile(smile) for smile in batch]

        pad_token = n_characters - 1 # last character
        lengths = [len(seq) - 1 for seq in lst_targets]
        max_len = max(lengths)
        # +1 column leaves room for the EOS of the longest sequence.
        padded_Y = np.ones((len(lst_targets), max_len + 1)) * pad_token

        for i, tgt in enumerate(lst_targets):
            padded_Y[i, :len(tgt)] = tgt.numpy()

        return padded_Y, lengths

In [None]:
max_len = 200

def sample(starting_token='C'):
    """Greedily complete a SMILES string that begins with ``starting_token``.

    The prefix is fed through the model one character at a time, then the
    arg-max prediction is appended and fed back until EOS is predicted or
    ``max_len`` characters have been generated.

    Args:
        starting_token: Prefix of the string to complete.

    Returns:
        ``(output_seq, prnt_str, unique_str)``: the generated string, a message
        stating whether RDKit could parse it (without sanitization), and, for
        parseable strings only, a message stating whether it already occurs in
        the dataset (``None`` otherwise).
    """
    eos_token = n_characters - 2
    pad_token = n_characters - 1

    with torch.no_grad():
        tokens = gen.target_smile(starting_token).to(device)  # prefix indices + EOS
        hidden = lstm.sample_init()

        # Condition the hidden state on the whole prefix, not just its first
        # character. For an empty prefix the EOS index is fed, as before.
        n_prefix = max(len(starting_token), 1)
        for token in tokens[:n_prefix - 1]:
            _, hidden = lstm(token.unsqueeze(-1), 1, hidden)
        inp = tokens[n_prefix - 1]

        output_seq = starting_token

        for _ in range(max_len):
            output, hidden = lstm(inp.unsqueeze(-1), 1, hidden)
            # Feed the predicted index straight back in; re-encoding the text
            # '<P>' would pass '<' to the model instead of the PAD token.
            inp = output.squeeze(1).topk(1)[1][0][0]
            predicted = inp.item()
            if predicted == eos_token:
                break
            output_seq += '<P>' if predicted == pad_token else all_characters[predicted]

    unique_str = None
    if Chem.MolFromSmiles(output_seq, sanitize=False) is not None:
        prnt_str = "The SMILE is valid"
        # `x in series` tests the index labels, so compare against the values.
        if (data['Ligand SMILES'] == output_seq).any():
            unique_str = "This SMILE is in the dataset"
        else:
            unique_str = "This SMILE is unique"
    else:
        prnt_str = "The SMILE is not valid"

    return output_seq, prnt_str, unique_str

In [None]:
class LSTM(nn.Module):
    """Character-level model: embedding -> ``nn.LSTM`` -> linear layer -> log-softmax
    over the ``n_characters`` vocabulary, processing one time step per call."""

    def __init__(self, hidden_size, batch_size):
        """
        Args:
            hidden_size: Number of hidden units in the recurrent layer.
            batch_size: Batch dimension used by ``init_hidden``.
        """
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.embedding_dim = 256
        self.n_layers = 1
        self.batch_size = batch_size

        # takes in hidden units, outputs the prediction of char
        self.fc1 = nn.Linear(self.hidden_size, n_characters)
        # NOTE(review): registered but not applied anywhere in forward().
        self.dropout = nn.Dropout(0.1)
        self.softmax = nn.LogSoftmax(dim=1)

        self.build_model()

    def build_model(self):
        """Create the embedding table and the recurrent layer."""
        self.embedding = nn.Embedding(
            num_embeddings = n_characters,
            embedding_dim = self.embedding_dim,
            padding_idx = n_characters-1)

        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.hidden_size,
            num_layers=self.n_layers,
            batch_first=True)

    def _random_state(self, batch_size):
        """Return a pair of standard-normal tensors shaped (n_layers, batch_size, hidden_size)."""
        shape = (self.n_layers, batch_size, self.hidden_size)
        # Plain tensors suffice; the Variable wrapper has been a no-op since PyTorch 0.4.
        hidden_a = torch.randn(*shape).to(device)
        hidden_b = torch.randn(*shape).to(device)
        return (hidden_a, hidden_b)

    def init_hidden(self):
        """Random initial state for a training batch of ``self.batch_size`` sequences."""
        return self._random_state(self.batch_size)

    def sample_init(self):
        """Random initial state for sampling a single sequence (batch of 1)."""
        return self._random_state(1)

    def forward(self, x, length, hidden):
        """Advance the model by one time step.

        Args:
            x: 1-D tensor of character indices, one per sequence in the batch.
            length: Accepted for interface compatibility; not used here.
            hidden: State tuple as returned by ``init_hidden``/``sample_init``
                or by a previous call.

        Returns:
            ``(log_probs, hidden)`` with ``log_probs`` shaped
            ``(batch, 1, n_characters)``.
        """
        # (batch,) -> (batch, 1): a single time step per sequence.
        x = x.unsqueeze(-1)
        batch_size, seq_len = x.size()

        x = self.embedding(x)

        x, hidden = self.lstm(x, hidden)

        # Flatten the time steps so the linear layer sees (batch*seq_len, hidden).
        x = x.contiguous().view(batch_size * seq_len, -1)

        x = self.fc1(x)

        x = self.softmax(x)

        x = x.view(batch_size, seq_len, n_characters)

        return x, hidden


In [None]:
# PAD (index n_characters - 1) only fills batches to a common width, so those
# target positions are excluded from the loss.
criterion = nn.NLLLoss(ignore_index=n_characters - 1)

learning_rate = .005
gen = smile_generator()
def train():
    """Run one SGD step on a freshly drawn batch of SMILES.

    At each time step the model is given column ``i`` of the batch and scored
    against column ``i + 1``; the per-step losses are summed, back-propagated,
    and every parameter is moved by ``-learning_rate * grad``.

    Returns:
        The summed loss divided by the number of time steps (average
        per-character loss).

    Raises:
        ValueError: If the batch has no (input, target) pair to train on.
    """
    # The batch must match the batch dimension of lstm.init_hidden().
    y_train, length = gen.get_n_samples(n=lstm.batch_size)
    y_train, length = torch.LongTensor(y_train).to(device), torch.tensor(length).to(device)
    hidden = lstm.init_hidden()

    n_steps = y_train.size(1) - 1
    if n_steps < 1:
        raise ValueError("Batch contains no character to predict")

    lstm.zero_grad()
    loss = 0

    for i in range(n_steps):
        output, hidden = lstm(y_train[:, i], length, hidden)
        output = output.squeeze(1)
        loss += criterion(output, y_train[:, i+1])

    loss.backward()

    #if nans in loss, uncomment
    #nn.utils.clip_grad_norm_(lstm.parameters(), 2)

    for p in lstm.parameters():
        if p.grad is not None:
            # add_(Number, Tensor) is deprecated; use the alpha keyword instead.
            p.data.add_(p.grad.data, alpha=-learning_rate)

    # Normalise by sequence length (dimension 1), not by batch size (dimension 0).
    return loss.item() / n_steps

In [None]:
# Build the model (1024 hidden units, batches of 128) and move it to the selected device.
lstm = LSTM(hidden_size=1024, batch_size=128).to(device)

In [None]:
n_iters = 10000
print_every = 200
plot_every = 500
all_losses = []
total_loss = 0 # Reset every plot_every iters

start = time.time()
# `iteration` rather than `iter`: the latter would shadow the builtin for the
# rest of the kernel session.
for iteration in range(1, n_iters + 1):
    loss = train()
    total_loss += loss

    # Periodic progress report with one generated sample.
    if iteration % print_every == 0:
        print("Training time: {}. Iter: {}.".format(time_since(start), iteration))
        print("Loss: {}".format(loss))
        SMILE, validity_str, unique_str = sample()
        print("Randomly generated sample: {}".format(SMILE))
        print(validity_str)
        if unique_str is not None:
            print(unique_str)

    # Record the mean loss of the last plot_every iterations.
    if iteration % plot_every == 0:
        all_losses.append(total_loss / plot_every)
        total_loss = 0