<a href="https://colab.research.google.com/github/yingzibu/a_inhibitor_design/blob/main/examples/experiments/Molecule_RNN_08_25.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! git clone https://github.com/shiwentao00/Molecule-RNN.git --quiet

In [3]:
! pip install rdkit --quiet
! pip install selfies --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.7/29.7 MB[0m [31m18.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
cd Molecule-RNN

/content/Molecule-RNN


In [1]:
import torch
import re
from tqdm import tqdm
import yaml
import selfies as sf
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

In [3]:
selfie_vocab_path = '/content/Molecule-RNN/vocab/selfies_merged_vocab.yaml'

class SELFIEVocab:
    def __init__(self, vocab_path):
        self.name ='selfies'
        with open(vocab_path, 'r') as f: self.vocab = yaml.full_load(f)
        self.int2token = {value:key for key, value in self.vocab.items()}
    def tokenize_smiles(self, mol): #mol is selfie actually
        ints = [self.vocab['<sos>']]
        selfies_list = list(sf.split_selfies(mol))
        ints += [self.vocab[token] for token in selfies_list]
        ints.append(self.vocab['<eos>'])
        return ints
    def combine_list(self, selfies): return "".join(selfies)

import pandas as pd

class SMILESDataset(Dataset):
    def __init__(self, smiles_file, percentage, vocab):
        super(SMILESDataset, self).__init__()
        assert (0<percentage<=1)
        self.percentage = percentage
        self.vocab = vocab
        if smiles_file.split('.')[-1] == 'smi':
            self.data = self.read_smiles_file(smiles_file)
        elif smiles_file.split('.')[-1] == 'csv':
            self.data = self.read_csv_file(smiles_file)
        print('total number of SMILES loaded: ', len(self.data))

        if self.vocab.name == 'selfies':
            self.data = self.smi2sf()
            print('total valid SELFIES:', len(self.data))
    def __getitem__(self, index):
        s_mol = self.data[index]
        vec_mol = self.vocab.tokenize_smiles(s_mol)
        return vec_mol

    def __len__(self): return len(self.data)

    def smi2sf(self):
        smiles = self.data
        sfs = []
        for i in tqdm(smiles, total=len(smiles)):
            try:
                s_ = sf.encoder(i)
                if s_ is not None: sfs.append(s_)
            except: pass
        return sfs



    def read_smiles_file(self, path):
        with open(path, 'r') as f:
            smiles = [line.strip("\n") for line in f.readlines()]
        num_data = len(smiles)
        return smiles[0:int(num_data*self.percentage)]

    def read_csv_file(self, path):
        data = pd.read_csv(path)
        for i in ['Smiles', 'smiles', 'SMILES']:
            if i in data.columns:
                smiles = data[i].tolist()
                num_data = len(smiles)
                return smiles[0:int(num_data*self.percentage)]

smile_dir = '/content/Molecule-RNN/dataset/chembl28-cleaned.smi'
selfie_vocab_path = '/content/Molecule-RNN/vocab/selfies_merged_vocab.yaml'

def dataloader_gen(percentage, batch_size, PADDING_IDX, shuffle,
                   dataset_dir=smile_dir, which_vocab='selfies',
                   vocab_path=selfie_vocab_path, drop_last=True):
    print('which_vocab:', which_vocab)
    if which_vocab == 'selfies':
        vocab = SELFIEVocab(vocab_path)
        dataset = SMILESDataset(dataset_dir, percentage, vocab)
    else:
        dataset = None
        vocab = None

    def pad_collate(batch):
        lengths = [len(x) for x in batch]
        batch = [torch.tensor(x, dtype=torch.long) for x in batch]
        x_padded = pad_sequence(batch, batch_first=True,
                                padding_value=PADDING_IDX)
        return x_padded, lengths

    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            shuffle=shuffle, drop_last=drop_last,
                            collate_fn=pad_collate)
    return dataloader, len(dataset)

In [4]:
rnn_config = {
    'num_embeddings': 78,
    'embedding_dim': 256,
    'rnn_type': "GRU",

    "input_size": 256,
    'hidden_size': 512,
    'num_layers': 3,
    'dropout': 0,

    "batch_size": 512,
    'shuffle': True,
    'num_epoch': 10,
    'which_optimizer': 'adam',
    'learning_rate': 0.001,
    'weight_decay': 1.0e-4
}


In [31]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.functional import softmax

class RNN(nn.Module):
    def __init__(self, rnn_config):
        super(RNN, self).__init__()

        self.emb = nn.Embedding(num_embeddings = rnn_config['num_embeddings'],
                                embedding_dim = rnn_config['embedding_dim'],
                                padding_idx = rnn_config['num_embeddings'] - 1)

        if rnn_config['rnn_type'] == 'LSTM':
            self.rnn = nn.LSTM(input_size = rnn_config['input_size'],
                               hidden_size = rnn_config['hidden_size'],
                               num_layers = rnn_config['num_layers'],
                               batch_first=True, dropout=rnn_config['dropout'])
        elif rnn_config['rnn_type'] == 'GRU':
            self.rnn = nn.GRU(input_size = rnn_config['input_size'],
                            hidden_size = rnn_config['hidden_size'],
                            num_layers = rnn_config['num_layers'],
                            batch_first=True, dropout=rnn_config['dropout'])
        else: raise ValueError('rnn_type should be LSTM or GRU')

        self.linear = nn.Linear(rnn_config['hidden_size'],
                                rnn_config['num_embeddings']-2)

    def forward(self, data, lengths):
        # print('data shape: ', data.shape) # [batch_size, ...]
        embeddings = self.emb(data)
        # print('emb shape:', embeddings.shape) #[batch_size, ..., embedding_dim]
        embeddings = pack_padded_sequence(input=embeddings, lengths=lengths,
                                    batch_first=True, enforce_sorted=False)

        embeddings, _ = self.rnn(embeddings)
        embeddings = self.linear(embeddings.data)

        # print('here emb shpae:', embeddings.shape) # [batch_size*?, num_embedding-2]
        return embeddings

    def sample(self, batch_size, vocab, max_length=140):
        start_int = vocab.vocab['<sos>']
        if torch.cuda.is_available(): device='cuda'
        else: device = 'cpu'
        sos = torch.ones([batch_size, 1], dtype=torch.long, device=device)
        sos = sos * start_int
        output = []
        x = self.emb(sos)
        x, hidden = self.rnn(x)
        x = self.linear(x)
        x = softmax(x, dim=-1)
        x = torch.multinomial(x.squeeze(), 1)
        output.append(x)
        # x = sos
        finish = torch.zeros(batch_size, dtype=torch.bool).to(device)
        for _ in range(max_length):
            x = self.emb(x)
            x, hidden = self.rnn(x, hidden)
            x = self.linear(x)
            x = softmax(x, dim=-1)
            x = torch.multinomial(x.squeeze(), 1)
            output.append(x)
            eos_sampled = (x==vocab.vocab['<eos>']).data
            finish = torch.logical_or(finish, eos_sampled.squeeze())
            if torch.all(finish):
                return torch.cat(output, -1)
        return torch.cat(output, -1)

In [51]:
import yaml
import os
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.optim.lr_scheduler import ReduceLROnPlateau
from rdkit import Chem
import selfies as sf
from rdkit import rdBase
rdBase.DisableLog('rdApp.error')

def sample(model, vocab, batch_size):
    model.eval()
    sample_ints = model.sample(batch_size=batch_size, vocab=vocab).tolist()
    molecules = []
    for ints in sample_ints:
        mol = []
        for x in ints:
            if vocab.int2token[x] == '<eos>': break
            else: mol.append(vocab.int2token[x])
        string_mol = "".join(mol)
        molecules.append(string_mol)
    if vocab.name == 'selfies':
        molecules = [sf.decoder(x) for x in molecules]
    return molecules


def compute_valid_rate(molecules):
    num_valid, num_invalid = 0, 0
    valid_mols = set()
    for i in molecules:
        mol = Chem.MolFromSmiles(i)
        if mol is None: num_invalid +=1
        else:
            num_valid += 1
            valid_mols.add(i)
    assert len(molecules) == num_valid + num_invalid
    unique_num = len(valid_mols)
    print(unique_num)
    valid_rate = num_valid / len(molecules)
    unique_rate = unique_num / len(molecules)
    return valid_rate, unique_rate


In [7]:
cuda = torch.cuda.is_available()
if cuda: device = 'cuda'
else: device = 'cpu'

out_dir = 'model_save/'
if not os.path.exists(out_dir): os.makedirs(out_dir)

with open(out_dir+'config.yaml', 'w') as f: yaml.dump(rnn_config, f)

In [8]:
smile_dir = '/content/Molecule-RNN/dataset/chembl28-cleaned.smi'
selfie_vocab_path = '/content/Molecule-RNN/vocab/selfies_merged_vocab.yaml'

In [56]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
SELFIES: a robust representation of semantically constrained graphs with an
    example application in chemistry (https://arxiv.org/abs/1905.13741)
    by Mario Krenn, Florian Haese, AkshatKuman Nigam, Pascal Friederich,
    Alan Aspuru-Guzik.

    Variational Autoencoder (VAE) for chemistry
        comparing SMILES and SELFIES representation using reconstruction
        quality, diversity and latent space validity as metrics of
        interest

information:
    ML framework: pytorch
    chemistry framework: RDKit

    get_selfie_and_smiles_encodings_for_dataset
        generate complete encoding (inclusive alphabet) for SMILES and
        SELFIES given a data file

    VAEEncoder
        fully connected, 3 layer neural network - encodes a one-hot
        representation of molecule (in SMILES or SELFIES representation)
        to latent space

    VAEDecoder
        decodes point in latent space using an RNN

    latent_space_quality
        samples points from latent space, decodes them into molecules,
        calculates chemical validity (using RDKit's MolFromSmiles), calculates
        diversity
"""

import os
import sys
import time

import numpy as np
import pandas as pd
import torch
import yaml
from rdkit import rdBase
from rdkit.Chem import MolFromSmiles
from torch import nn
import selfies as sf


# from data_loader import \
#     multiple_selfies_to_hot, multiple_smile_to_hot

rdBase.DisableLog('rdApp.error')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

"""
This file is to encode SMILES and SELFIES into one-hot encodings
"""



import pandas as pd
from os import walk
import torch
from rdkit.Chem.Draw import IPythonConsole, MolsToGridImage
from rdkit import Chem
m = Chem.MolFromSmiles
import math


def smile_to_hot(smile, largest_smile_len, alphabet):
    """Go from a single smile string to a one-hot encoding.
    """

    char_to_int = dict((c, i) for i, c in enumerate(alphabet))

    # pad with ' '
    smile += ' ' * (largest_smile_len - len(smile))

    # integer encode input smile
    integer_encoded = [char_to_int[char] for char in smile]

    # one hot-encode input smile
    onehot_encoded = list()
    for value in integer_encoded:
        letter = [0 for _ in range(len(alphabet))]
        letter[value] = 1
        onehot_encoded.append(letter)
    return integer_encoded, np.array(onehot_encoded)


def multiple_smile_to_hot(smiles_list, largest_molecule_len, alphabet):
    """Convert a list of smile strings to a one-hot encoding

    Returned shape (num_smiles x len_of_largest_smile x len_smile_encoding)
    """

    hot_list = []
    for s in smiles_list:
        _, onehot_encoded = smile_to_hot(s, largest_molecule_len, alphabet)
        hot_list.append(onehot_encoded)
    return np.array(hot_list)


def selfies_to_hot(selfie, largest_selfie_len, alphabet):
    """Go from a single selfies string to a one-hot encoding.
    """

    symbol_to_int = dict((c, i) for i, c in enumerate(alphabet))

    # pad with [nop]
    selfie += '[nop]' * (largest_selfie_len - sf.len_selfies(selfie))

    # integer encode
    symbol_list = sf.split_selfies(selfie)
    integer_encoded = [symbol_to_int[symbol] for symbol in symbol_list]

    # one hot-encode the integer encoded selfie
    onehot_encoded = list()
    for index in integer_encoded:
        letter = [0] * len(alphabet)
        letter[index] = 1
        onehot_encoded.append(letter)

    return integer_encoded, np.array(onehot_encoded)


def multiple_selfies_to_hot(selfies_list, largest_molecule_len, alphabet):
    """Convert a list of selfies strings to a one-hot encoding
    """

    hot_list = []
    for s in selfies_list:
        _, onehot_encoded = selfies_to_hot(s, largest_molecule_len, alphabet)
        hot_list.append(onehot_encoded)
    return np.array(hot_list)

def _make_dir(directory):
    os.makedirs(directory)


def save_models(encoder, decoder, epoch):
    out_dir = './saved_models/{}'.format(epoch)
    _make_dir(out_dir)
    torch.save(encoder, '{}/E'.format(out_dir))
    torch.save(decoder, '{}/D'.format(out_dir))


class VAEEncoder(nn.Module):

    def __init__(self, in_dimension, layer_1d, layer_2d, layer_3d,
                 latent_dimension):
        """
        Fully Connected layers to encode molecule to latent space
        """
        super(VAEEncoder, self).__init__()
        self.latent_dimension = latent_dimension

        # Reduce dimension up to second last layer of Encoder
        self.encode_nn = nn.Sequential(
            nn.Linear(in_dimension, layer_1d),
            nn.ReLU(),
            nn.Linear(layer_1d, layer_2d),
            nn.ReLU(),
            nn.Linear(layer_2d, layer_3d),
            nn.ReLU()
        )

        # Latent space mean
        self.encode_mu = nn.Linear(layer_3d, latent_dimension)

        # Latent space variance
        self.encode_log_var = nn.Linear(layer_3d, latent_dimension)

    @staticmethod
    def reparameterize(mu, log_var):
        """
        This trick is explained well here:
            https://stats.stackexchange.com/a/16338
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def forward(self, x):
        """
        Pass throught the Encoder
        """
        # Get results of encoder network
        h1 = self.encode_nn(x)

        # latent space
        mu = self.encode_mu(h1)
        log_var = self.encode_log_var(h1)

        # Reparameterize
        z = self.reparameterize(mu, log_var)
        return z, mu, log_var


class VAEDecoder(nn.Module):

    def __init__(self, latent_dimension, gru_stack_size, gru_neurons_num,
                 out_dimension):
        """
        Through Decoder
        """
        super(VAEDecoder, self).__init__()
        self.latent_dimension = latent_dimension
        self.gru_stack_size = gru_stack_size
        self.gru_neurons_num = gru_neurons_num

        # Simple Decoder
        self.decode_RNN = nn.GRU(
            input_size=latent_dimension,
            hidden_size=gru_neurons_num,
            num_layers=gru_stack_size,
            batch_first=False)

        self.decode_FC = nn.Sequential(
            nn.Linear(gru_neurons_num, out_dimension),
        )

    def init_hidden(self, batch_size=1):
        weight = next(self.parameters())
        return weight.new_zeros(self.gru_stack_size, batch_size,
                                self.gru_neurons_num)

    def forward(self, z, hidden):
        """
        A forward pass throught the entire model.
        """

        # Decode
        l1, hidden = self.decode_RNN(z, hidden)
        decoded = self.decode_FC(l1)  # fully connected layer

        return decoded, hidden


def is_correct_smiles(smiles):
    """
    Using RDKit to calculate whether molecule is syntactically and
    semantically valid.
    """
    if smiles == "":
        return False

    try:
        return MolFromSmiles(smiles, sanitize=True) is not None
    except Exception:
        return False


def sample_latent_space(vae_encoder, vae_decoder, sample_len):
    vae_encoder.eval()
    vae_decoder.eval()

    gathered_atoms = []

    fancy_latent_point = torch.randn(1, 1, vae_encoder.latent_dimension,
                                     device=device)
    hidden = vae_decoder.init_hidden()

    # runs over letters from molecules (len=size of largest molecule)
    for _ in range(sample_len):
        out_one_hot, hidden = vae_decoder(fancy_latent_point, hidden)

        out_one_hot = out_one_hot.flatten().detach()
        soft = nn.Softmax(0)
        out_one_hot = soft(out_one_hot)

        out_index = out_one_hot.argmax(0)
        gathered_atoms.append(out_index.data.cpu().tolist())

    vae_encoder.train()
    vae_decoder.train()

    return gathered_atoms

import random
import math
from rdkit.Chem.Draw import IPythonConsole, MolsToGridImage
from rdkit import Chem
mol_conv = Chem.MolFromSmiles

def latent_space_quality(vae_encoder, vae_decoder, type_of_encoding,
                         alphabet, sample_num, sample_len, epoch):
    total_correct = 0
    all_correct_molecules = set()
    print(f"latent_space_quality:"
          f" Take {sample_num} samples from the latent space")

    for _ in tqdm(range(1, sample_num + 1), total=sample_num):

        molecule_pre = ''
        for i in sample_latent_space(vae_encoder, vae_decoder, sample_len):
            molecule_pre += alphabet[i]
        molecule = molecule_pre.replace(' ', '')

        if type_of_encoding == 1:  # if SELFIES, decode to SMILES
            molecule = sf.decoder(molecule)

        if is_correct_smiles(molecule):
            total_correct += 1
            all_correct_molecules.add(molecule)

    MAX_GRID = 16
    show_num = min(len(all_correct_molecules), MAX_GRID)
    selected_smiles = random.choices([*all_correct_molecules], k=show_num)
    molecules = [mol_conv(i) for i in selected_smiles]

    img = MolsToGridImage(
        [m for m in molecules if m is not None],
        molsPerRow=int(math.sqrt(MAX_GRID)),
        subImgSize=(250, 150),  returnPNG=False
        # legends=legends,
    )
    display(img)
    import os
    if not os.path.exists('figs/'): os.makedirs('figs/')
    img.save(f'figs/latent_sample_{epoch}.png')
    print('figure saved at :', 'figs/', f'latent_sample_{epoch}.png')

    return total_correct, len(all_correct_molecules)


def quality_in_valid_set(vae_encoder, vae_decoder, data_valid, batch_size):
    data_valid = data_valid[torch.randperm(data_valid.size()[0])]  # shuffle
    num_batches_valid = len(data_valid) // batch_size

    quality_list = []
    for batch_iteration in range(min(25, num_batches_valid)):

        # get batch
        start_idx = batch_iteration * batch_size
        stop_idx = (batch_iteration + 1) * batch_size
        batch = data_valid[start_idx: stop_idx]
        _, trg_len, _ = batch.size()

        inp_flat_one_hot = batch.flatten(start_dim=1)
        latent_points, mus, log_vars = vae_encoder(inp_flat_one_hot)

        latent_points = latent_points.unsqueeze(0)
        hidden = vae_decoder.init_hidden(batch_size=batch_size)
        out_one_hot = torch.zeros_like(batch, device=device)
        for seq_index in range(trg_len):
            out_one_hot_line, hidden = vae_decoder(latent_points, hidden)
            out_one_hot[:, seq_index, :] = out_one_hot_line[0]

        # assess reconstruction quality
        quality = compute_recon_quality(batch, out_one_hot)
        quality_list.append(quality)

    return np.mean(quality_list).item()

import shutil
def train_model(vae_encoder, vae_decoder,
                data_train, data_valid, num_epochs, batch_size,
                lr_enc, lr_dec, KLD_alpha,
                sample_num, sample_len, alphabet, type_of_encoding):
    """
    Train the Variational Auto-Encoder
    """

    print('num_epochs: ', num_epochs)

    # initialize an instance of the model
    optimizer_encoder = torch.optim.Adam(vae_encoder.parameters(), lr=lr_enc)
    optimizer_decoder = torch.optim.Adam(vae_decoder.parameters(), lr=lr_dec)

    data_train = data_train.clone().detach().to(device)
    num_batches_train = int(len(data_train) / batch_size)

    quality_valid_list = [0, 0, 0, 0]
    folder_names = []
    for epoch in range(num_epochs):

        data_train = data_train[torch.randperm(data_train.size()[0])]

        start = time.time()
        for batch_iteration in range(num_batches_train):  # batch iterator

            # manual batch iterations
            start_idx = batch_iteration * batch_size
            stop_idx = (batch_iteration + 1) * batch_size
            batch = data_train[start_idx: stop_idx]

            # reshaping for efficient parallelization
            inp_flat_one_hot = batch.flatten(start_dim=1)
            latent_points, mus, log_vars = vae_encoder(inp_flat_one_hot)

            # initialization hidden internal state of RNN (RNN has two inputs
            # and two outputs:)
            #    input: latent space & hidden state
            #    output: one-hot encoding of one character of molecule & hidden
            #    state the hidden state acts as the internal memory
            latent_points = latent_points.unsqueeze(0)
            hidden = vae_decoder.init_hidden(batch_size=batch_size)

            # decoding from RNN N times, where N is the length of the largest
            # molecule (all molecules are padded)
            out_one_hot = torch.zeros_like(batch, device=device)
            for seq_index in range(batch.shape[1]):
                out_one_hot_line, hidden = vae_decoder(latent_points, hidden)
                out_one_hot[:, seq_index, :] = out_one_hot_line[0]

            # compute ELBO
            loss = compute_elbo(batch, out_one_hot, mus, log_vars, KLD_alpha)

            # perform back propogation
            optimizer_encoder.zero_grad()
            optimizer_decoder.zero_grad()
            loss.backward(retain_graph=True)
            nn.utils.clip_grad_norm_(vae_decoder.parameters(), 0.5)
            optimizer_encoder.step()
            optimizer_decoder.step()

            if (batch_iteration+1) % int(num_batches_train/3) == 0:
                end = time.time()

                # assess reconstruction quality
                quality_train = compute_recon_quality(batch, out_one_hot)
                quality_valid = quality_in_valid_set(vae_encoder, vae_decoder,
                                                     data_valid, batch_size)

                report = 'Epoch: %d,  Batch: %d / %d, (loss: %.3f | ' \
                         'quality: %.3f | quality_valid: %.3f)   ' \
                         'ELAPSED TIME: %.5f' \
                         % (epoch, batch_iteration, num_batches_train,
                            loss.item(), quality_train, quality_valid,
                            end - start)
                print(report)
                start = time.time()

        quality_valid = quality_in_valid_set(vae_encoder, vae_decoder,
                                             data_valid, batch_size)
        quality_valid_list.append(quality_valid)

        # only measure validity of reconstruction improved
        quality_increase = len(quality_valid_list) \
                           - np.argmax(quality_valid_list)
        if quality_increase == 1 and quality_valid_list[-1] > 50.:
            corr, unique = latent_space_quality(vae_encoder, vae_decoder,
                                                type_of_encoding, alphabet,
                                                sample_num, sample_len, epoch)

            model_path = f'vae/epoch_{epoch}/'
            if not os.path.exists(model_path): os.makedirs(model_path)
            folder_names.append(model_path)
            torch.save(vae_encoder.state_dict(), model_path+'enc.pt')
            torch.save(vae_decoder.state_dict(), model_path+'dec.pt')
            while len(folder_names) > 10:
                try:
                    remove_dir_name = folder_names.pop(0)
                    print('remove dir: ', remove_dir_name)
                    # os.rmdir(remove_dir_name) # cannot delete non empty dir

                    shutil.rmtree(remove_dir_name, ignore_errors=True)
                except:
                    print(f'error when delete {folder_names[0]}')
        else:
            corr, unique = -1., -1.

        report = 'Validity: %.3f %% | Diversity: %.3f %% | ' \
                 'Reconstruction: %.3f %%' \
                 % (corr * 100. / sample_num, unique * 100. / sample_num,
                    quality_valid)
        print(report)

        with open('results.dat', 'a') as content:
            content.write(report + '\n')

        if quality_valid_list[-1] < 70. and epoch > 200:
            break

        if quality_increase > 20:
            print('Early stopping criteria')
            break


def compute_elbo(x, x_hat, mus, log_vars, KLD_alpha):
    inp = x_hat.reshape(-1, x_hat.shape[2])
    target = x.reshape(-1, x.shape[2]).argmax(1)

    criterion = torch.nn.CrossEntropyLoss()
    recon_loss = criterion(inp, target)
    kld = -0.5 * torch.mean(1. + log_vars - mus.pow(2) - log_vars.exp())

    return recon_loss + KLD_alpha * kld


def compute_recon_quality(x, x_hat):
    x_indices = x.reshape(-1, x.shape[2]).argmax(1)
    x_hat_indices = x_hat.reshape(-1, x_hat.shape[2]).argmax(1)

    differences = 1. - torch.abs(x_hat_indices - x_indices)
    differences = torch.clamp(differences, min=0., max=1.).double()
    quality = 100. * torch.mean(differences)
    quality = quality.detach().cpu().numpy()

    return quality



def main(file_name_smiles, type_of_encoding, dynamic_z_dim=False):
    content = open('logfile.dat', 'w')
    content.close()
    content = open('results.dat', 'w')
    content.close()

    if os.path.exists("settings.yml"):
        settings = yaml.safe_load(open("settings.yml", "r"))
    else:
        print("Expected a file settings.yml but didn't find it.")
        return

    print('--> Acquiring data...')
    # type_of_encoding = settings['data']['type_of_encoding']
    # file_name_smiles = settings['data']['smiles_file']

    print('Finished acquiring data.')

    if type_of_encoding == 0:
        print('Representation: SMILES')
        print('There are :', len(file_name_smiles), ' compounds')
        _, _, _, encoding_list, encoding_alphabet, largest_molecule_len = \
            get_selfie_and_smiles_encodings_for_dataset(file_name_smiles)

        print('--> Creating one-hot encoding...')
        data = multiple_smile_to_hot(encoding_list, largest_molecule_len,
                                     encoding_alphabet)
        print('Finished creating one-hot encoding.')

    elif type_of_encoding == 1:
        print('Representation: SELFIES')
        encoding_list, encoding_alphabet, largest_molecule_len, _, _, _ = \
            get_selfie_and_smiles_encodings_for_dataset(file_name_smiles)

        print('--> Creating one-hot encoding...')
        data = multiple_selfies_to_hot(encoding_list, largest_molecule_len,
                                       encoding_alphabet)
        print('Finished creating one-hot encoding.')

    else:
        print("type_of_encoding not in {0, 1}.")
        return
    len_compound = data.shape[0]
    len_max_molec = data.shape[1]
    len_alphabet = data.shape[2]
    len_max_mol_one_hot = len_max_molec * len_alphabet

    print(' ')
    print(f"Dataset has {len_compound} compounds, "
          f"Alphabet has {len_alphabet} letters, "
          f"largest molecule is {len_max_molec} letters.")

    data_parameters = settings['data']
    batch_size = data_parameters['batch_size']

    if dynamic_z_dim:
        z_dim = int(len_max_mol_one_hot/300)*100
        settings['encoder']['latent_dimension'] = z_dim
        settings['decoder']['latent_dimension'] = z_dim
        print(f'z_dim {z_dim} is set to ~ 1/3 of in_dim {len_max_mol_one_hot}')
    print(settings['encoder'])
    encoder_parameter = settings['encoder']
    decoder_parameter = settings['decoder']
    training_parameters = settings['training']

    vae_encoder = VAEEncoder(in_dimension=len_max_mol_one_hot,
                             **encoder_parameter).to(device)
    vae_decoder = VAEDecoder(**decoder_parameter,
                             out_dimension=len(encoding_alphabet)).to(device)

    print('*' * 15, ': -->', device)

    data = torch.tensor(data, dtype=torch.float).to(device)

    train_valid_test_size = [0.5, 0.5, 0.0]
    data = data[torch.randperm(data.size()[0])]
    idx_train_val = int(len(data) * train_valid_test_size[0])
    idx_val_test = idx_train_val + int(len(data) * train_valid_test_size[1])

    data_train = data[0:idx_train_val]
    data_valid = data[idx_train_val:idx_val_test]

    print("start training")
    train_model(**training_parameters,
                vae_encoder=vae_encoder,
                vae_decoder=vae_decoder,
                batch_size=batch_size,
                data_train=data_train,
                data_valid=data_valid,
                alphabet=encoding_alphabet,
                type_of_encoding=type_of_encoding,
                sample_len=len_max_molec)

    with open('COMPLETED', 'w') as content:
        content.write('exit code: 0')


# if __name__ == '__main__':
#     try:
#         main()
#     except AttributeError:
#         _, error_message, _ = sys.exc_info()
#         print(error_message)

In [54]:
def read_smiles_file(path):
    with open(path, 'r') as f:
        smiles = [line.strip("\n") for line in f.readlines()]
    num_data = len(smiles)
    return smiles

smiles_all = read_smiles_file(smile_dir)
len(smiles_all)

538247

In [59]:
a, b, c, d, e, f = get_selfie_and_smiles_encodings_for_dataset(smiles_all)

--> Translating SMILES to SELFIES...
Finished translating 538247 SMILES to SELFIES.


In [None]:
encoding_list = a
encoding_alphabet = b
largest_molecule_len = c
data = multiple_selfies_to_hot(encoding_list, largest_molecule_len,
                                       encoding_alphabet)
len_compound = data.shape[0]
len_max_molec = data.shape[1]
len_alphabet = data.shape[2]
len_max_mol_one_hot = len_max_molec * len_alphabet

print(' ')
print(f"Dataset has {len_compound} compounds, "
        f"Alphabet has {len_alphabet} letters, "
        f"largest molecule is {len_max_molec} letters.")

In [58]:

def get_selfie_and_smiles_encodings_for_dataset(df):
    """
    Returns encoding, alphabet and length of largest molecule in SMILES and
    SELFIES, given a file containing SMILES molecules.

    input:
        csv file with molecules. Column's name must be 'smiles'.
    output:
        - selfies encoding
        - selfies alphabet
        - longest selfies string
        - smiles encoding (equivalent to file content)
        - smiles alphabet (character based)
        - longest smiles string
    """
    # if isinstance(file_path_or_df, str):
    #     file_path = file_path_or_df
    #     df = pd.read_csv(file_path)
    # else: df = file_path_or_df

    smiles_list = np.asanyarray(df)

    smiles_alphabet = list(set(''.join(smiles_list)))
    smiles_alphabet.append(' ')  # for padding

    largest_smiles_len = len(max(smiles_list, key=len))

    print('--> Translating SMILES to SELFIES...')
    selfies_list = list(map(sf.encoder, smiles_list))

    all_selfies_symbols = sf.get_alphabet_from_selfies(selfies_list)
    all_selfies_symbols.add('[nop]')
    selfies_alphabet = list(all_selfies_symbols)

    largest_selfies_len = max(sf.len_selfies(s) for s in selfies_list)

    print(f'Finished translating {len(selfies_list)} SMILES to SELFIES.')

    return selfies_list, selfies_alphabet, largest_selfies_len, \
           smiles_list, smiles_alphabet, largest_smiles_len


In [9]:
print(rnn_config)
dataset_dir = smile_dir
which_vocab = 'selfies'
vocab_path = selfie_vocab_path
percentage = 0.5

{'num_embeddings': 78, 'embedding_dim': 256, 'rnn_type': 'GRU', 'input_size': 256, 'hidden_size': 512, 'num_layers': 3, 'dropout': 0, 'batch_size': 512, 'shuffle': True, 'num_epoch': 10, 'which_optimizer': 'adam', 'learning_rate': 0.001, 'weight_decay': 0.0001}


In [10]:
batch_size = rnn_config['batch_size']
shuffle = rnn_config['shuffle']
PADDING_IDX = rnn_config['num_embeddings']-1
num_workers = os.cpu_count()
print('number of workers to load data:', num_workers)
print('which vocab to use: ', which_vocab)
vocab = SELFIEVocab(selfie_vocab_path)

number of workers to load data: 2
which vocab to use:  selfies


In [11]:
dataloader, train_size = dataloader_gen(percentage, batch_size, PADDING_IDX,
                                        shuffle, drop_last=False)

which_vocab: selfies
total number of SMILES loaded:  269123


100%|██████████| 269123/269123 [02:59<00:00, 1501.23it/s]

total valid SELFIES: 269123





In [32]:
model = RNN(rnn_config).to(device)
learning_rate = rnn_config['learning_rate']
weight_decay = rnn_config['weight_decay']
loss_function = nn.CrossEntropyLoss(reduction='sum')

if rnn_config['which_optimizer'] == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                 weight_decay=weight_decay, amsgrad=True)
elif rnn_config['which_optimizer'] == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,
                                weight_decay=weight_decay, momentum=0.9)
elif rnn_config['which_optimizer'] == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate,
                                 weight_decay=weight_decay, amsgrad=True)
else: raise ValueError('wrong optimizer. should be adam, sgd or adamw')

scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5,
                              cooldown=10, min_lr=1e-4, verbose=True)

In [13]:
import torch
def save_model(model, path):
    torch.save(model.state_dict(), path)

def load_model(model, path):
    cuda = torch.cuda.is_available()
    if cuda: device = 'cuda'
    else: device = 'cpu'
    model.load_state_dict(torch.load(path, map_location=device))

In [33]:
train_losses = []
best_valid_rate = 0
num_epoch = rnn_config['num_epoch']
print(f'begin training {num_epoch} epochs')
model_names = []
best_epoch = 0

begin training 10 epochs


In [37]:
if best_epoch != 0:
    load_model(model, out_dir + f'epoch_{best_epoch}.pt')


for epoch in range(1, 1+num_epoch):
    model.train()
    train_loss = 0
    for data, lengths in tqdm(dataloader, total=len(dataloader),
                              desc=f'epoch {epoch}'): # batch_size = 512
        # print(lengths)
        lengths = [l-1 for l in lengths]
        optimizer.zero_grad()
        data = data.to(device)
        preds = model(data, lengths)
        targets = pack_padded_sequence(
            data[:, 1:], lengths, batch_first=True, enforce_sorted=False).data
        # preds: [num, len_alphabet-2], target: [num]
        # print('hre pred, target, ', preds.shape, targets.shape)
        loss = loss_function(preds, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_losses.append(train_loss/train_size)
    print(f'epoch {epoch}, train_loss: {train_losses[-1]:.3f}')
    scheduler.step(train_losses[-1])
    sampled_molecules = sample(model, vocab, batch_size=1024)
    valid, unique = compute_valid_rate(sampled_molecules)
    print(f'valid:  {(valid * 100):.3f} % | unique: {(unique * 100):.3f} %')
    if valid > best_valid_rate:
        best_epoch = epoch
        trained_model_dir = out_dir + f'epoch_{epoch}.pt'
        print(f'save model: ', trained_model_dir)
        save_model(model, trained_model_dir)
        model_names.append(trained_model_dir)
        while len(model_names) > 10:
            try:
                remove_name = model_names.pop(0)
                print('remove ', remove_name)
                os.remove(remove_name)
            except: print(f'error remove {model_names[0]}')


epoch 1: 100%|██████████| 526/526 [02:28<00:00,  3.55it/s]


epoch 1, train_loss: 43.709
0
valid:  100.000 % | unique: 0.000 %
save model:  model_save/epoch_1.pt


epoch 2: 100%|██████████| 526/526 [02:32<00:00,  3.44it/s]


epoch 2, train_loss: 39.369
0
valid:  100.000 % | unique: 0.000 %
save model:  model_save/epoch_2.pt


epoch 3: 100%|██████████| 526/526 [02:27<00:00,  3.58it/s]


epoch 3, train_loss: 36.672
0
valid:  100.000 % | unique: 0.000 %
save model:  model_save/epoch_3.pt


epoch 4:  20%|█▉        | 104/526 [00:29<01:59,  3.53it/s]


KeyboardInterrupt: ignored

In [50]:


compute_valid_rate(sampled_molecules)

1024


(1.0, 1.0)

In [49]:
dataset

NameError: ignored

In [41]:
valid_mols

()

In [16]:
best_dir = out_dir + f'epoch_{best_epoch}.pt'
with open(out_dir +'config.yaml', 'r') as f: config = yaml.full_load(f)
model = RNN(config).to(device)
load_model(model, best_dir)
model.eval()
num_batches = 10
for _ in tqdm(range(num_batches)):
    sampled_ints = model.sample(batch_size=batch_size,
                                vocab=vocab, device=device).tolist()
    mols = []
    for ints in sampled_ints:
        s_mol = []
        for x in ints:
            if vocab.int2token[x] =='<eos>':break
            else: s_mol.append(vocab.int2token[x])
        mols.append("".join(s_mol))
    mols = [sf.decoder(x) for x in mols]


512

In [132]:
selfie_vocab_path = '/content/Molecule-RNN/vocab/selfies_merged_vocab.yaml'
with open(selfie_vocab_path, 'r') as f:
    vocab = yaml.full_load(f)

In [133]:
vocab = SELFIEVocab(selfie_vocab_path)

In [15]:
import pandas as pd
df = pd.DataFrame([2, 3], columns=['SMILES'])
df

Unnamed: 0,SMILES
0,2
1,3


In [57]:
x = torch.randn([3, 4])
x

tensor([[-0.1928,  1.4340, -0.1538,  1.4391],
        [ 0.1171, -0.4854, -1.4899,  1.0256],
        [ 1.6104,  0.7337, -0.1302, -1.3865]])

In [68]:
x

tensor([[0.0817, 0.4156, 0.0849, 0.4178],
        [0.2365, 0.1295, 0.0474, 0.5866],
        [0.6092, 0.2535, 0.1069, 0.0304]])

In [73]:
d = [x, x]
torch.cat(d, -1).shape

torch.Size([3, 8])

In [47]:
x.view(x.shape[0], -1).shape

torch.Size([3, 8])

In [65]:
x_ = torch.multinomial(x.squeeze(), 1)
x_

tensor([[3],
        [1],
        [0]])

In [53]:
torch.ones([12, 1])*76

tensor([[76.],
        [76.],
        [76.],
        [76.],
        [76.],
        [76.],
        [76.],
        [76.],
        [76.],
        [76.],
        [76.],
        [76.]])

In [26]:
vocab.tokenizer_smiles(sfs)

[76, 14, 14, 63, 75]

In [None]:
vocab