# 1
Helper functions for batching one-hot encoded SMILES strings and sampling new SMILES strings from a trained LSTM.

In [8]:
import torch
import torch.nn as nn
import numpy as np

def batches_gen(smiles, batchsize, encoder):
    '''Yield (inputs, targets) training batches built from tokenised SMILES.

    Each sequence in ``smiles`` is one-hot encoded with ``encoder``. The inputs
    are every token except the last; the targets are the same sequence shifted
    forward by one token. All sequences are zero-padded to the longest one, so
    each yielded tensor has shape (batchsize, longest_len - 1, n_chars), where
    n_chars is the width of the encoder's one-hot output.

    Only full batches are yielded: a trailing partial batch is dropped because
    of the hidden-state batch-size constraint.

    Arguments
    ---------
    smiles: list of token sequences, one sequence of tokens per SMILES string
    batchsize: number of sequences per batch
    encoder: one-hot encoder; ``transform`` takes a column of tokens and returns
        an object with ``toarray()`` (e.g. a SciPy sparse matrix)
    '''
    pad = nn.utils.rnn.pad_sequence

    # One float tensor of shape (seq_len, n_chars) per sequence; lengths vary.
    encoded = []
    for seq in smiles:
        column = np.array(seq).reshape(-1, 1)
        dense = np.array(encoder.transform(column).toarray())
        encoded.append(torch.tensor(dense, dtype=torch.float))

    # Features drop the final token, targets drop the first (next-token setup);
    # both are padded so every row has the same length.
    inputs = pad([e[:-1, :] for e in encoded], batch_first=True)
    targets = pad([e[1:, :] for e in encoded], batch_first=True)

    # Step through full batches only; any remainder is discarded.
    n_full = len(encoded) // batchsize
    for start in range(0, n_full * batchsize, batchsize):
        stop = start + batchsize
        yield inputs[start:stop], targets[start:stop]

    
   



In [35]:
# Defining a method to generate the next character
def predict(net, inputs, h, top_k=None):
    ''' Given a one-hot encoded character, sample the next character.

    Returns the sampled character one-hot encoded (a float tensor with the same
    shape and device as ``inputs``) and the updated hidden state. The sampled
    index is set at every time step of ``inputs``, so this is intended for
    single-step inputs.

    Arguments:
        net: the LSTM model, called as ``net(inputs, h)`` and returning
            ``(out, h)``. The last row of ``out`` (over its final dimension) is
            used as the sampling weights. NOTE(review): this assumes ``out``
            holds non-negative probabilities (e.g. a softmax output), not raw
            logits -- confirm against the model.
        inputs: input to the lstm model, shape (batch, time_step, input_size)
            with a batch size of 1.
        h: hidden state tuple (h, c).
        top_k: int or None. If given, sample only among the ``top_k`` most
            likely characters; otherwise sample over the whole vocabulary.
    '''
    # Detach the hidden state so sampling does not extend the autograd history.
    h = tuple(each.detach() for each in h)
    out, h = net(inputs, h)

    # Flatten to (rows, n_chars) and keep the last row as a 1-D weight vector.
    # Keeping it 1-D (instead of .squeeze()) is what lets top_k=1 work:
    # np.random.choice requires a 1-D ``p``. .cpu() allows .numpy() for GPU models.
    probs = out.detach().reshape(-1, out.shape[-1])[-1].cpu()

    if top_k is None:
        candidates = np.arange(probs.shape[0])  # every index in the vocabulary
    else:
        probs, top_idx = probs.topk(top_k)
        candidates = top_idx.numpy()

    # Normalise in float64 so the weights sum to 1 within numpy's tolerance.
    weights = probs.double().numpy()
    char = np.random.choice(candidates, p=weights / weights.sum())

    output = torch.zeros_like(inputs, dtype=torch.float)
    output[:, :, int(char)] = 1.0
    return output, h

# Declaring a method to generate new text
def sample(net, encoder, prime=('SOS',), top_k=None, max_len=None):
    """Generate a SMILES string by sampling from ``net`` one token at a time.

    I use 'SOS' (start of string) and 'EOS' (end of string) as boundary tokens.
    You may need to change this based on your starting and ending character.

    Arguments:
        net: the LSTM model; this function calls ``net.eval()`` (the model is
            left in eval mode) and ``net.init_state(1)``, and passes the model
            to ``predict``.
        encoder: one-hot encoder fitted on the token vocabulary; ``transform``
            must return an object with ``toarray()``, and ``inverse_transform``
            maps one-hot rows back to tokens.
        prime: non-empty iterable of tokens fed to the model before sampling.
            The prime tokens themselves are NOT included in the result.
        top_k: if given, sample each token from the ``top_k`` most likely ones.
        max_len: optional cap (>= 1) on the number of sampled tokens, counting
            a final 'EOS'. ``None`` (default) samples until 'EOS' is produced,
            which never terminates if the model never emits it.

    Returns:
        The sampled tokens joined into a string, without the trailing 'EOS'.

    Raises:
        ValueError: if ``prime`` is empty.
    """
    prime = list(prime)
    if not prime:
        raise ValueError("prime must contain at least one token")

    def _one_hot(token):
        # (1, n_chars) encoder row -> (batch=1, time=1, n_chars) model input.
        row = encoder.transform(np.array([token]).reshape(-1, 1)).toarray()
        return torch.tensor(row, dtype=torch.float).reshape(1, 1, -1)

    net.eval()  # eval mode
    with torch.no_grad():  # inference only; no autograd graph needed
        h = net.init_state(1)  # hidden state for batch size 1

        # Run through the prime tokens; only the prediction after the last one
        # is kept as the first sampled token.
        char = None
        for token in prime:
            char, h = predict(net, _one_hot(token), h, top_k=top_k)
        generated = [char]

        end = _one_hot('EOS')
        # Feed back the previous prediction until 'EOS' (or the length cap).
        while not torch.all(end.eq(generated[-1])):
            if max_len is not None and len(generated) >= max_len:
                break
            char, h = predict(net, generated[-1], h, top_k=top_k)
            generated.append(char)

    finished = bool(torch.all(end.eq(generated[-1])))
    n_chars = end.shape[-1]
    onehots = np.array([g.detach().cpu().numpy() for g in generated])
    tokens = encoder.inverse_transform(onehots.reshape(-1, n_chars)).reshape(-1)
    if finished:
        tokens = tokens[:-1]  # drop the trailing 'EOS'
    return ''.join(tokens)

To check whether a generated SMILES string is valid, paste it into https://chemwriter.com/smiles/ — a valid string is rendered as a structure drawing.