# Description

This notebook trains an LSTM model on a large corpus of known SMILES strings so that it learns to generate new, chemically valid small molecules. We then use this initial network to generate our generation-0 (gen0) candidate molecules.

In [None]:
# Environment setup (IPython shell magics): install RDKit, NumPy and Bunch,
# and install Miniconda into /usr/local so conda can provide RDKit.
# NOTE(review): RDKit is installed twice — once via apt (python-rdkit) and once
# via conda-forge. Presumably only the conda build is used (see the sys.path
# line at the end of this cell) — confirm and drop the other.
# NOTE(review): `apt install` has no -y flag and may stop at a confirmation prompt.
!sudo apt install python-rdkit
!pip install numpy
!wget -c https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
!chmod +x Miniconda3-latest-Linux-x86_64.sh
# Non-interactive Miniconda install with prefix /usr/local.
!time bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
!time conda install -q -y -c conda-forge rdkit
!pip install bunch
%matplotlib inline
import matplotlib.pyplot as plt
import sys
import os
# Make packages installed under the Miniconda prefix importable from this kernel.
# NOTE(review): hard-codes python3.7; "Miniconda3-latest" may ship a different
# Python version, in which case this directory will not exist — confirm.
sys.path.append('/usr/local/lib/python3.7/site-packages/')

## Train the Network

In [None]:
import tensorflow
# Show whether TensorFlow can see a GPU; the cell displays the returned value.
# NOTE(review): tf.test.is_gpu_available is deprecated in TF 2.x in favour of
# tf.config.list_physical_devices('GPU') — confirm which TF version is installed.
tensorflow.test.is_gpu_available()

In [None]:
#Import the required libraries
import numpy as np
from copy import copy
import keras
from lstm_chem.utils.config import process_config
from lstm_chem.model import LSTMChem
from lstm_chem.generator import LSTMChemGenerator
from lstm_chem.trainer import LSTMChemTrainer
from lstm_chem.data_loader import DataLoader

In [None]:
# JSON config for this experiment run. The parsed `config` object is passed to
# the model, data loader and generator constructed in the cells below.
CONFIG_FILE = 'experiments/2019-12-23/LSTM_Chem/config.json'
config = process_config(CONFIG_FILE)

In [None]:
# Build the LSTM_Chem model for training (session='train').
modeler = LSTMChem(config, session='train')

In [None]:
# Data loader for the training split (data_type='train').
train_dl = DataLoader(config, data_type='train')

In [None]:
# The validation loader is a *shallow* copy of the training loader with only
# `data_type` switched, so it shares whatever data train_dl already loaded.
# NOTE(review): this presumes DataLoader picks the train/valid split from
# `data_type` at access time rather than in __init__ — confirm, otherwise
# validation would silently run on training data.
valid_dl = copy(train_dl)
valid_dl.data_type = 'valid'

In [None]:
# Combine the model with the training and validation loaders.
trainer = LSTMChemTrainer(modeler, train_dl, valid_dl)

In [None]:
# Train the model (presumably validating on valid_dl each epoch — see LSTMChemTrainer).
trainer.train()

In [None]:
# Save the trained weights. This filename must match the path assigned to
# config['model_weight_filename'] in the generation section below. With a
# doubled ".hdf5.hdf5" extension the generator would load some other file
# instead of the weights trained in this notebook.
trainer.model.save_weights('experiments/2019-12-23/LSTM_Chem/checkpoints/LSTM_Chem-baseline-model-full.hdf5')

## Load the trained model and generate new molecules (SMILES)

In [None]:
# Point the config at the trained weights. This must match the filename passed
# to save_weights in the training section.
config['model_weight_filename'] = 'experiments/2019-12-23/LSTM_Chem/checkpoints/LSTM_Chem-baseline-model-full.hdf5'
print(config)

In [None]:
# Rebuild the model in generation mode (presumably loading
# config['model_weight_filename'] — confirm in LSTMChem) and wrap it in a sampler.
modeler = LSTMChem(config, session='generate')
generator = LSTMChemGenerator(modeler)
# NOTE(review): prints the same config already printed in the previous cell.
print(config)

In [None]:
# Number of SMILES strings to draw from the trained generator for gen0.
sample_number = 25_000
# Raw generator output; strings may be chemically invalid and are filtered below.
sampled_smiles = generator.sample(num=sample_number)

In [None]:
# RDKit is used below to parse the generated SMILES and convert them back to strings.
from rdkit import RDLogger, Chem, DataStructs
from rdkit.Chem import AllChem, Draw, Descriptors
from rdkit.Chem.Draw import IPythonConsole
# Silence all RDKit log output; parse failures on invalid samples are expected
# and are counted (not logged) in the metrics cell.
RDLogger.DisableLog('rdApp.*')

In [None]:
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is 0.

    Keeps the metric printout from raising ZeroDivisionError when a run
    produces no valid (or no unique) molecules.
    """
    return numerator / denominator if denominator else 0.0


# Validity: fraction of sampled strings RDKit can parse into a molecule
# (Chem.MolFromSmiles returns None for unparseable input).
valid_mols = [mol for mol in map(Chem.MolFromSmiles, sampled_smiles) if mol is not None]
# Divide by the number of strings actually sampled, not the requested count.
print('Validity: ', f'{_safe_ratio(len(valid_mols), len(sampled_smiles)):.2%}')

# Uniqueness: fraction of valid molecules whose SMILES string is distinct.
valid_smiles = [Chem.MolToSmiles(mol) for mol in valid_mols]
unique_valid_smiles = set(valid_smiles)
print('Uniqueness: ', f'{_safe_ratio(len(unique_valid_smiles), len(valid_smiles)):.2%}')

# Originality: of the distinct valid SMILES generated, the fraction that does
# not occur in the training data.
# NOTE(review): valid_smiles are written by Chem.MolToSmiles (canonical by
# default), so this exact-string membership test is only meaningful if
# all_smiles_clean.smi was canonicalised the same way. Otherwise originality
# is overestimated — confirm.
import pandas as pd
training_data = pd.read_csv('./datasets/all_smiles_clean.smi', header=None)
training_set = set(training_data[0])
original = [smile for smile in valid_smiles if smile not in training_set]
print('Originality: ', f'{_safe_ratio(len(set(original)), len(unique_valid_smiles)):.2%}')

In [None]:
import os

# Write the valid generated SMILES, one per line, as the gen0 candidate pool.
# Duplicates are written as generated; deduplicate here if gen0 must be unique.
# Create the output directory first so a fresh checkout does not fail with
# FileNotFoundError.
os.makedirs('./generations', exist_ok=True)
with open('./generations/gen0.smi', 'w', encoding='utf-8') as f:
    f.writelines(f"{item}\n" for item in valid_smiles)