In [None]:
import os, sys, time
import json
import numpy as np
import torch
import pandas as pd
import selfies
from torch import nn
torch.backends.cudnn.benchmark = True

from utils import model_dir
from VAE import VAE_encode, VAE_decode
from train import train_model

In [None]:
# Directory the compound library is read from (used as a string prefix, so it
# must keep its trailing slash).
path = '../Comp_Lib/'

# Output directory for this run (imported from utils).
print(model_dir)
# exist_ok=True: re-running this cell in the same kernel must not raise
# FileExistsError once the directory has been created.
os.makedirs(model_dir, exist_ok=True)

In [None]:
# Hyperparameters for the encoder, the decoder and the training loop.

# When True, saved weights are loaded into the models before training.
warm_start = True

# Latent size; the same value is repeated in all three settings groups.
latent_dimension = 100

# Key order is kept stable because the dict is serialised to settings.json.
settings = dict(
    encoder=dict(
        layer_1d=500,
        layer_2d=200,
        layer_3d=100,
        latent_dimension=latent_dimension,
    ),
    decoder=dict(
        latent_dimension=latent_dimension,
        gru_neurons_num=200,
        gru_stack_size=3,
    ),
    training=dict(
        batch_size=2500,
        latent_dimension=latent_dimension,
        KLD_alpha=0.0001,
        lr_enc=0.0001,
        lr_dec=0.0001,
        num_epochs=2000,
    ),
)

In [None]:
# Split the settings into the keyword groups passed to the encoder, decoder
# and training routine.
encoder_parameter, decoder_parameter, training_parameters = (
    settings[group] for group in ('encoder', 'decoder', 'training')
)

# Load one library file only to discover the input dimensions.
# NOTE(review): assumes '0.npy' holds a single molecule with axis 0 = position
# and axis 1 = alphabet symbol (one-hot) -- confirm against the library writer.
data = np.load(f'{path}0.npy')

len_max_molec, len_alphabet = data.shape[0], data.shape[1]
# Product of the two dimensions; passed to the encoder as its input size.
len_max_molec1Hot = len_max_molec * len_alphabet

In [None]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('-->' + str(device))

# Build both halves of the VAE and place them on the selected device.
model_encode = VAE_encode(len_max_molec1Hot, **encoder_parameter).to(device)
model_decode = VAE_decode(len_alphabet, **decoder_parameter).to(device)

if warm_start:

    # warm start: resume from a previously saved checkpoint

    warm_path = './Models/2020-09-17 17-01-39'
    warm_epoch = 5500
    # NOTE(review): warm_epoch (5500) is larger than
    # settings['training']['num_epochs'] (2000) -- confirm how train_model
    # combines the two values.
    warm_model = os.path.join(warm_path, 'Epochs:', str(warm_epoch))

    # map_location=device remaps the stored tensors onto the device chosen
    # above; without it a checkpoint written from a GPU cannot be loaded on a
    # CPU-only machine.
    encoder_state = torch.load(os.path.join(warm_model, 'encode.tar'),
                               map_location=device)
    decoder_state = torch.load(os.path.join(warm_model, 'decode.tar'),
                               map_location=device)

    # load_state_dict copies into the existing parameters, so the models stay
    # on `device` and need no further .to(device) call.
    model_encode.load_state_dict(encoder_state['state_dict'])
    model_decode.load_state_dict(decoder_state['state_dict'])

else:

    # cool start: train from the freshly initialised weights

    warm_epoch = None

In [None]:
# Train the model: put both networks in training mode, record the run
# settings, call train_model, then leave a completion marker.

models = [model_encode, model_decode]
for network in models:
    network.train()

# Shuffled sample indices handed to train_model.
# NOTE(review): 1248664 is hard-coded -- presumably the number of entries in
# the compound library; verify against the contents of `path`. A smaller range
# (e.g. 10000) gives a quick trial run.
idx = list(range(1248664))
np.random.shuffle(idx)

# Persist the hyperparameters next to the model outputs.
with open(f'{model_dir}/settings.json', 'w') as settings_file:
    json.dump(settings, settings_file)

print("start training")
with open(f'{model_dir}/training_log.txt', 'w') as log_file:
    log_file.write("start training\n")

# Positional arguments for train_model, in the order it is called with.
modules = [models, idx, path, device, warm_epoch]
train_model(*modules, **training_parameters)

# Written only after train_model returns without raising.
with open(f'{model_dir}/COMPLETED.txt', 'w') as marker_file:
    marker_file.write('exit code: 0')