In [1]:
import mxnet as mx
from mxnet import gluon, autograd, nd
import gc
import time
import os
import numpy as np
import json
import argparse
import sys
import models, data
from data import get_mol_spec
import pandas as pd

In [2]:
# ---- Paths and device -------------------------------------------------------
file_name = 'datasets/input_Norm.txt'
ckpt_dir = 'ckpt'
ctx = mx.cpu()
cond = data.Delimited()

# ---- Data / loader settings -------------------------------------------------
N_C = 7
batch_size = 5
batch_size_test = 15
num_workers = 0
num_folds = 5
fold_id = 0
k = 3
p = 0.8

# ---- Model architecture -----------------------------------------------------
F_e = 16
F_h = (32, 64, 128, 128, 256, 256)
F_skip = 256
F_c = (512, )
Fh_policy = 128
activation = 'relu'
N_rnn = 3

# ---- Optimisation -----------------------------------------------------------
lr = 1e-3
wd = 0.0005
clip_grad = 3.0
iterations = 1000
summary_step = 10

# Every line of the input file is one sample; the line terminator is removed
# lazily through data.Lambda.
with open(file_name) as f:
    dataset = data.Lambda(f.readlines(), lambda line: line.strip('\n').strip('\r'))

# Resume a previous run only when all of its artefacts are present in ckpt_dir.
_resume_files = ('log.out', 'ckpt.params', 'trainer.status')
is_continuous = all(os.path.isfile(os.path.join(ckpt_dir, name)) for name in _resume_files)
if is_continuous:
    print("continue")


In [3]:
def _first_field_len(line):
    """Return the length of the first tab-separated field of ``line``.

    Used as the per-sample cost for ``data.BalancedSampler`` (presumably the
    molecule string is in that first column -- TODO confirm against the input
    file format).
    """
    return len(line.split('\t')[0])


db_train = data.KFold(dataset, k=num_folds, fold_id=fold_id, is_train=True)
db_test = data.KFold(dataset, k=num_folds, fold_id=fold_id, is_train=False)

sampler_train = data.BalancedSampler(cost=[_first_field_len(l) for l in db_train], batch_size=batch_size)
loader_train = data.CMolRNNLoader(db_train, batch_sampler=sampler_train, num_workers=num_workers,
                                  k=k, p=p, conditional=cond, prefetch=2)

# Previously `len(l.split('\t'[0]))`: `'\t'[0]` indexed the separator string, so the
# test cost was the number of tab-separated columns instead of the first field's length.
sampler_test = data.BalancedSampler(cost=[_first_field_len(l) for l in db_test], batch_size=batch_size_test)
loader_test = data.CMolRNNLoader(db_test, batch_sampler=sampler_test, num_workers=num_workers,
                                 k=k, p=p, conditional=cond, prefetch=2)

it_train, it_test = iter(loader_train), iter(loader_test)

In [4]:
if not is_continuous:
    # Fresh run: persist the architecture hyper-parameters so that a later
    # resume rebuilds the model with exactly the same configuration.
    configs = {'N_C': N_C,
               'F_e': F_e,
               'F_h': F_h,
               'F_skip': F_skip,
               'F_c': F_c,
               'Fh_policy': Fh_policy,
               'activation': activation,
               'rename': True,
               'N_rnn': N_rnn}
    # ckpt_dir may not exist yet on a first run; open(..., 'w') would raise
    # FileNotFoundError without it.
    os.makedirs(ckpt_dir, exist_ok=True)
    with open(os.path.join(ckpt_dir, 'configs.json'), 'w') as f:
        json.dump(configs, f)
else:
    # Resume: reuse the saved configuration. Note that JSON round-trips the
    # tuples F_h / F_c as lists.
    with open(os.path.join(ckpt_dir, 'configs.json')) as f:
        configs = json.load(f)

In [5]:
# Instantiate the model; the atom/bond type counts come from the molecule spec,
# the remaining hyper-parameters from `configs`.
mol_spec = get_mol_spec()
model = models.CVanillaMolGen_RNN(
    mol_spec.num_atom_types,
    mol_spec.num_bond_types,
    D=2,
    **configs
)

In [6]:
# Parameters are always (re)initialised with Xavier first; when resuming, the
# checkpointed weights are then loaded on top of that initialisation.
if is_continuous:
    print("continue")
model.collect_params().initialize(mx.init.Xavier(magnitude=1), force_reinit=True, ctx=ctx)
if is_continuous:
    model.load_parameters(os.path.join(ckpt_dir, 'ckpt.params'), ctx=ctx)

In [7]:
# Adam optimizer (learning rate `lr`, weight decay `wd`) driving all model
# parameters; a resumed run restores the trainer state saved with the checkpoint.
opt = mx.optimizer.Adam(learning_rate=lr, wd=wd)
trainer = gluon.Trainer(model.collect_params(), optimizer=opt)
if is_continuous:
    trainer_status_path = os.path.join(ckpt_dir, 'trainer.status')
    trainer.load_states(trainer_status_path)

In [8]:
if not is_continuous:
    # Fresh run: start the clock, the step counter and the epoch counter at zero.
    t0 = time.time()
    global_counter = 0
    epochs = 0
else:
    # Resume from the most recent per-epoch row of log.out. Rows have the form
    # 'step\tepochs\ttime(s)\tloss\tlr\tval_loss'. The header, any
    # 'Training finished' marker lines and blank lines are skipped by requiring
    # a purely numeric step field.
    log_path = os.path.join(ckpt_dir, 'log.out')
    with open(log_path) as f:
        records = f.readlines()
    final_record = next(
        (r for r in reversed(records) if r.split('\t', 1)[0].strip().isdigit()),
        None,
    )
    if final_record is None:
        raise RuntimeError('No epoch record found in {}; cannot resume.'.format(log_path))
    print(final_record)
    fields = final_record.split('\t')
    count, epochs, t_final = int(fields[0]), float(fields[1]), float(fields[2])
    # Shift t0 back so the logged elapsed time continues from the previous run.
    t0 = time.time() - t_final
    global_counter = count

In [None]:
batch_num = len(loader_train)   # optimisation steps per epoch
epoch_size = 200                # total number of epochs to train for
decay = 0.015                   # multiplicative LR decay factor ...
decay_step = 10000              # ... applied every `decay_step` global steps
# Best validation loss seen in this session; inf so the first epoch always
# produces a best-model checkpoint (the old sentinel 10000 was arbitrary).
compare_val_loss = float('inf')
# Best-model checkpoints live inside ckpt_dir. Previously the directory was
# concatenated without a separator (writing e.g. 'ckpt01-....ckpt' next to it),
# and ':4f' was a width rather than the intended 4-decimal precision.
model_path = os.path.join(ckpt_dir, '{epoch:02d}-{val_loss:.4f}.ckpt')
mx.random.seed(129)


def _next_batch(iterator, loader):
    """Return ``(iterator, batch)``, restarting ``loader`` once ``iterator`` is exhausted."""
    try:
        return iterator, next(iterator)
    except StopIteration:
        iterator = iter(loader)
        return iterator, next(iterator)


with open(os.path.join(ckpt_dir, 'log.out'), mode='w' if not is_continuous else 'a') as f:
    if not is_continuous:
        f.write('step\tepochs\ttime(s)\tloss\tlr\tval_loss\n')
    # Training losses of the current epoch only, stored as Python scalars so no
    # NDArrays are kept alive across steps.
    losses = []
    # Loop condition (rather than `== epoch_size` + break) also terminates a
    # resumed run that is already at or beyond epoch_size.
    while int(epochs) < epoch_size:
        t1 = time.time()
        global_counter += 1
        it_train, inputs = _next_batch(it_train, loader_train)
        inputs = data.CMolRNNLoader.from_numpy_to_tensor(inputs, ctx=ctx)
        with autograd.record():
            loss = model(*inputs)
            loss = sum(loss)
            loss.backward()
        nd.waitall()
        gc.collect()
        losses.append(loss.asscalar())
        trainer.step(1, ignore_stale_grad=True)
        if global_counter % 100 == 0:
            print(str(global_counter) + "  " + str(loss))
        if global_counter % decay_step == 0:
            trainer.set_learning_rate(trainer.learning_rate * (1.0 - decay))

        if global_counter % batch_num == 0:
            # ---- end of epoch: summarise, validate, checkpoint, log ----
            epochs += 1
            mean_loss = np.mean(losses)
            losses = []  # reset so the next epoch's mean covers only its own steps
            print("mean_loss", mean_loss)

            val_losses = []
            ctx.empty_cache()
            for _ in range(len(loader_test)):
                it_test, val_inputs = _next_batch(it_test, loader_test)
                val_inputs = data.CMolRNNLoader.from_numpy_to_tensor(val_inputs, ctx=ctx)
                val_loss = model(*val_inputs)
                val_loss = sum(val_loss)
                nd.waitall()
                gc.collect()
                val_losses.append(val_loss.asscalar())
            mean_val_loss = np.mean(val_losses)
            ctx.empty_cache()

            # Rolling "latest" checkpoint used for resuming.
            model.save_parameters(os.path.join(ckpt_dir, 'ckpt.params'))
            trainer.save_states(os.path.join(ckpt_dir, 'trainer.status'))
            f.write('{}\t{}\t{}\t{}\t{}\t{}\n'.format(global_counter, epochs, float(time.time() - t0),
                                                      mean_loss, trainer.learning_rate, mean_val_loss))
            f.flush()
            print("< epochs = ", int(epochs), "| loss = ", mean_loss, "| time = ", time.time() - t0)

            # Separate best-so-far checkpoint keyed by epoch and validation loss.
            if mean_val_loss < compare_val_loss:
                model.save_parameters(model_path.format(epoch=int(epochs), val_loss=mean_val_loss))
                compare_val_loss = mean_val_loss

    model.save_parameters(os.path.join(ckpt_dir, 'ckpt.params'))
    trainer.save_states(os.path.join(ckpt_dir, 'trainer.status'))
    f.write('Training finished\n')