In [1]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm
import torch
import torch.nn as nn
from torch.nn import Parameter
import sys
sys.path.append('.')
sys.path.append('./probtorch')

import probtorch
from probtorch.util import expand_inputs
# Report library versions and CUDA availability.  The message is formatted
# explicitly because under Python 2 `print(a, b, ...)` prints the repr of a
# tuple instead of the space-separated values.
print('probtorch: %s torch: %s cuda: %s' % (
    probtorch.__version__, torch.__version__, torch.cuda.is_available()))

import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.autograd import Variable

import math, random
import argparse
from collections import deque
import cPickle as pickle

from fast_jtnn import *
import rdkit

('probtorch:', '0.0+4b7e140', 'torch:', '1.3.1', 'cuda:', False)


In [2]:
# Checkpoint iteration to resume from; 0 means start from scratch.
load_epoch = 0

# Directory holding the "model.iter-<step>" checkpoints.  The loading and
# saving code reads this name but it was never assigned, so the first
# checkpoint save failed with a NameError.
# NOTE(review): the path below is a placeholder - point it at an existing
# directory before training.
save_dir = 'fast_molvae/vae_model'

alpha = 0.1 * 500 #0.1 * Number of samples

# Model sizes handed to the JTNNVAE constructor.
hidden_size=450
batch_size=4
latent_size=56
depthT=20
depthG=3
y_size=2

# Optimisation: Adam learning rate and gradient-norm clipping threshold.
lr=1e-3
clip_norm=50.0
# beta is passed to the model on every step.  Once `warmup` steps have
# passed it grows by `step_beta` every `kl_anneal_iter` steps, capped at
# `max_beta`.
beta=0.0
step_beta=0.002
max_beta=1.0
warmup=40000

# Schedule: number of epochs, learning-rate decay factor, and the step
# intervals for lr decay, beta annealing, logging and checkpointing.
epoch=20
anneal_rate=0.9
anneal_iter=40000
kl_anneal_iter=2000
print_iter=50
save_iter=5000

# NOTE(review): the training data lives under "moses-processed" while the
# vocabulary comes from "tox21" - confirm the two actually belong together.
train_folder = 'fast_molvae/moses-processed/'
vocab_file = 'data/tox21/vocab.txt'

In [3]:
# Read one vocabulary entry per line, trimming line endings and spaces.
# The `with` block closes the file handle, which the bare open() call left
# open; the word list gets its own name so `vocab` only ever refers to the
# Vocab object.
with open(vocab_file) as vocab_fh:
    vocab_words = [line.strip("\r\n ") for line in vocab_fh]
vocab = Vocab(vocab_words)

In [4]:
# Build the model from the vocabulary and the sizes set in the
# configuration cell, then print its module structure.
model = JTNNVAE(
    vocab,
    hidden_size,
    latent_size,
    y_size,
    depthT,
    depthG,
    alpha,
)
print(model)

JTNNVAE(
  (jtnn): JTNNEncoder(
    (embedding): Embedding(550, 450)
    (outputNN): Sequential(
      (0): Linear(in_features=900, out_features=450, bias=True)
      (1): ReLU()
    )
    (GRU): GraphGRU(
      (W_z): Linear(in_features=900, out_features=450, bias=True)
      (W_r): Linear(in_features=450, out_features=450, bias=False)
      (U_r): Linear(in_features=450, out_features=450, bias=True)
      (W_h): Linear(in_features=900, out_features=450, bias=True)
    )
  )
  (decoder): JTNNDecoder(
    (embedding): Embedding(550, 450)
    (W_z): Linear(in_features=900, out_features=450, bias=True)
    (U_r): Linear(in_features=450, out_features=450, bias=False)
    (W_r): Linear(in_features=450, out_features=450, bias=True)
    (W_h): Linear(in_features=900, out_features=450, bias=True)
    (W): Linear(in_features=478, out_features=450, bias=True)
    (U): Linear(in_features=478, out_features=450, bias=True)
    (U_i): Linear(in_features=900, out_features=450, bias=True)
    (W_o): 



In [5]:
# Initialise every parameter: one-dimensional tensors are set to zero,
# all others receive Xavier-normal initialisation.
for weight in model.parameters():
    if weight.dim() != 1:
        nn.init.xavier_normal_(weight)
    else:
        nn.init.constant_(weight, 0)

In [6]:
# Restore weights from an earlier checkpoint when a non-zero iteration is
# configured.
if load_epoch > 0:
    checkpoint_path = save_dir + "/model.iter-" + str(load_epoch)
    model.load_state_dict(torch.load(checkpoint_path))

# Report the total number of parameter elements, in thousands.
num_params = sum(p.nelement() for p in model.parameters())
print("Model #Params: %dK" % (num_params / 1000,))

Model #Params: 5368K


In [7]:
# Adam optimiser with an exponentially decaying learning rate.
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, anneal_rate)
# No scheduler.step() here.  Since PyTorch 1.1 the scheduler must be stepped
# after optimizer.step(); stepping it right after construction skips the
# first value of the schedule, so training would start at lr * anneal_rate
# instead of lr.



In [8]:
def param_norm(m):
    """Return the L2 norm taken over all parameters of module `m`."""
    squared_total = 0
    for tensor in m.parameters():
        squared_total += tensor.norm().item() ** 2
    return math.sqrt(squared_total)


def grad_norm(m):
    """Return the L2 norm taken over all gradients of module `m`.

    Parameters whose `.grad` is None are skipped, so the result is 0.0 when
    no parameter has a gradient yet.
    """
    squared_total = 0
    for tensor in m.parameters():
        if tensor.grad is None:
            continue
        squared_total += tensor.grad.norm().item() ** 2
    return math.sqrt(squared_total)

In [None]:
# Training loop.  Every step pushes one unsupervised and one supervised
# batch through the model and optimises the sum of the two losses.
total_step = load_epoch
# Running sums of the KL, Word, Topo and Assm values that are printed below.
meters = np.zeros(4)

# The loop variable is deliberately not called `epoch`: reusing that name
# overwrote the configured epoch count, so re-running this cell trained for
# a different number of epochs.
for epoch_idx in xrange(epoch):
    loader = MolTreeFolder(train_folder, vocab, batch_size=batch_size)
    for (supervised_batch, unsupervised_batch) in loader:
        # Only train on steps where both batches are full.
        if len(supervised_batch['labels']) != batch_size or len(unsupervised_batch['labels']) != batch_size:
            continue
        total_step += 1

        # One-hot encode the class indices.  The width follows y_size
        # rather than a hard-coded 2.
        label_idx = torch.from_numpy(supervised_batch['labels'])[:, None].long()
        labels = torch.zeros(batch_size, y_size).scatter_(1, label_idx, 1)

        model.zero_grad()
        # Only the loss of the unsupervised pass is used; the logged
        # metrics come from the supervised pass alone.
        # NOTE(review): confirm that ignoring the unsupervised metrics is
        # intended.
        unsupervised_loss = model(unsupervised_batch['data'], None, beta)[0]
        supervised_loss, kl_div, wacc, tacc, sacc = model(supervised_batch['data'], labels, beta)

        loss = unsupervised_loss + supervised_loss
        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
        optimizer.step()

        meters = meters + np.array([kl_div, wacc * 100, tacc * 100, sacc * 100])

        if total_step % print_iter == 0:
            # Average over the logging window, print it, then reset.
            meters /= print_iter
            print("[%d] Beta: %.3f, KL: %.2f, Word: %.2f, Topo: %.2f, Assm: %.2f, PNorm: %.2f, GNorm: %.2f" % (total_step, beta, meters[0], meters[1], meters[2], meters[3], param_norm(model), grad_norm(model)))
            sys.stdout.flush()
            meters *= 0

        if total_step % save_iter == 0:
            # `save_dir` must be assigned before this cell runs.
            torch.save(model.state_dict(), save_dir + "/model.iter-" + str(total_step))

        if total_step % anneal_iter == 0:
            scheduler.step()
            # Read the rate from the optimiser: scheduler.get_lr() is not
            # meant to be called outside step() and can report a value
            # that is one decay ahead.
            print("learning rate: %.6f" % optimizer.param_groups[0]['lr'])

        if total_step % kl_anneal_iter == 0 and total_step >= warmup:
            beta = min(max_beta, beta + step_beta)

  input = module(input)


In [None]:
def test(model, infer=True):
    """Evaluate the ELBO and classification accuracy over `data`.

    NOTE(review): this function reads the globals `data`, `enc`, `dec`,
    `elbo`, `NUM_BATCH`, `NUM_PIXELS`, `NUM_SAMPLES` and `CUDA`, none of
    which are assigned in the visible part of this notebook.  `model` is
    only switched to eval mode; the forward passes go through `enc` and
    `dec`.  Confirm this cell is still meant to be used with the model
    trained in this notebook.

    Args:
        model: module to put in eval mode.
        infer: if True, predict the class from an importance-weighted
            average of the sampled 'digits' values; otherwise take the
            argmax of the sampled values directly.

    Returns:
        Tuple (mean ELBO per item, accuracy) over all batches whose first
        dimension equals NUM_BATCH.

    Raises:
        ZeroDivisionError: if no batch has exactly NUM_BATCH items.
    """
    model.eval()

    epoch_elbo = 0.0
    epoch_correct = 0
    N = 0
    for images, labels in data:
        # Batches of any other size are skipped.
        if images.size()[0] == NUM_BATCH:
            N += NUM_BATCH
            images = images.view(-1, NUM_PIXELS)
            if CUDA:
                images = images.cuda()
            q = enc(images, num_samples=NUM_SAMPLES)
            p = dec(images, q, num_samples=NUM_SAMPLES)
            batch_elbo = elbo(q, p)
            if CUDA:
                batch_elbo = batch_elbo.cpu()
            epoch_elbo += batch_elbo.item()
            if infer:
                # Normalised weights: softmax over dim 0 of log p - log q.
                # Presumably dim 0 is the sample dimension - TODO confirm.
                log_p = p.log_joint(0, 1)
                log_q = q.log_joint(0, 1)
                log_w = log_p - log_q
                w = torch.nn.functional.softmax(log_w, 0)
                y_samples = q['digits'].value
                y_expect = (w.unsqueeze(-1) * y_samples).sum(0)
                _ , y_pred = y_expect.max(-1)
                if CUDA:
                    y_pred = y_pred.cpu()
                epoch_correct += (labels == y_pred).sum().item()
            else:
                _, y_pred = q['digits'].value.max(-1)
                if CUDA:
                    y_pred = y_pred.cpu()
                # float(): with an integer NUM_SAMPLES, Python 2 would
                # otherwise floor this division.
                epoch_correct += float((labels == y_pred).sum().item()) / (NUM_SAMPLES or 1.0)
    # float(): int / int is floor division under Python 2, which would
    # collapse the accuracy to 0 or 1.
    return epoch_elbo / N, float(epoch_correct) / N