In [1]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm
import torch
import torch.nn as nn
from torch.nn import Parameter
import os
import sys
sys.path.append('.')

import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.autograd import Variable

import math, random
import argparse
from collections import deque
import cPickle as pickle

from fast_jtnn import *
import rdkit

In [2]:
# ---- Checkpointing ----
# Iteration number of the checkpoint to resume from ("<save_dir>/model.iter-<load_epoch>");
# 0 trains from scratch. Despite the name this is a step count, not an epoch index.
load_epoch = 0
# Directory for checkpoints. It is read by the resume and training cells and was
# previously never defined, which made the first checkpoint save fail with NameError.
save_dir = 'fast_molvae/vae_model'

# ---- Reproducibility ----
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# ---- Model ----
# Passed to JTNNVAE. Original note: 0.1 * number of samples.
# NOTE(review): presumably weights the supervised classification term -- confirm in fast_jtnn.
alpha = 0.1 * 7000

hidden_size = 450
batch_size = 8
latent_size = 56
depthT = 20   # passed to JTNNVAE; presumably tree message-passing depth -- confirm
depthG = 3    # passed to JTNNVAE; presumably graph message-passing depth -- confirm
y_size = 2    # label dimension passed to JTNNVAE; must match the one-hot label width used in training

# ---- Optimisation ----
lr = 1e-3          # Adam learning rate
clip_norm = 50.0   # max total gradient norm for clip_grad_norm_
beta = 0.0         # initial KL weight passed to the model
step_beta = 0.002  # added to beta every kl_anneal_iter steps once total_step >= warmup
max_beta = 1.0     # upper bound for beta
warmup = 40000     # steps before KL annealing starts

epoch = 50           # number of training epochs
anneal_rate = 0.9    # ExponentialLR gamma
anneal_iter = 40000  # scheduler.step() every this many steps
kl_anneal_iter = 2000
save_iter = 5000     # checkpoint every this many steps
print_iter = 100     # log running averages every this many steps

test_epoch = 1  # evaluate on the test set every this many epochs

# ---- Data ----
train_folder = 'fast_molvae/moses-processed/train'
test_folder = 'fast_molvae/moses-processed/test'
vocab_file = 'data/tox21/vocab.txt'

In [3]:
# Load the vocabulary: one entry per line, with leading/trailing spaces and line
# endings stripped. The `with` block closes the file handle (it was previously leaked).
with open(vocab_file) as vocab_fh:
    vocab_entries = [line.strip("\r\n ") for line in vocab_fh]
vocab = Vocab(vocab_entries)

In [4]:
# Build the semi-supervised JT-VAE and move it to the GPU (this unconditionally
# calls .cuda(), so a CUDA device is required here).
model = JTNNVAE(vocab, hidden_size, latent_size, y_size, depthT, depthG, alpha)
model = model.cuda()
print(model)



JTNNVAE(
  (jtnn): JTNNEncoder(
    (embedding): Embedding(550, 450)
    (outputNN): Sequential(
      (0): Linear(in_features=900, out_features=450, bias=True)
      (1): ReLU()
    )
    (GRU): GraphGRU(
      (W_z): Linear(in_features=900, out_features=450, bias=True)
      (W_r): Linear(in_features=450, out_features=450, bias=False)
      (U_r): Linear(in_features=450, out_features=450, bias=True)
      (W_h): Linear(in_features=900, out_features=450, bias=True)
    )
  )
  (decoder): JTNNDecoder(
    (embedding): Embedding(550, 450)
    (W_z): Linear(in_features=900, out_features=450, bias=True)
    (U_r): Linear(in_features=450, out_features=450, bias=False)
    (W_r): Linear(in_features=450, out_features=450, bias=True)
    (W_h): Linear(in_features=900, out_features=450, bias=True)
    (W): Linear(in_features=478, out_features=450, bias=True)
    (U): Linear(in_features=478, out_features=450, bias=True)
    (U_i): Linear(in_features=900, out_features=450, bias=True)
    (W_o): 

In [5]:
# Initialise weights: every 1-D parameter (e.g. biases) is zeroed; every other
# parameter gets Xavier-normal initialisation.
for weight in model.parameters():
    if weight.dim() != 1:
        nn.init.xavier_normal_(weight)
    else:
        nn.init.constant_(weight, 0)

In [6]:
# Optionally resume from "<save_dir>/model.iter-<load_epoch>", then report the
# parameter count in thousands.
if load_epoch > 0:
    ckpt_path = save_dir + "/model.iter-" + str(load_epoch)
    model.load_state_dict(torch.load(ckpt_path))

n_params = sum(p.nelement() for p in model.parameters())
print("Model #Params: %dK" % (n_params / 1000,))

Model #Params: 5368K


In [7]:
# Adam over all model parameters. The learning rate is multiplied by `anneal_rate`
# each time scheduler.step() is called.
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=anneal_rate)
# NOTE(review): this initial step() follows the pre-1.1 PyTorch convention (a no-op
# on the LR there). On PyTorch >= 1.1 it already decays the LR once before
# training -- confirm this is intended.
scheduler.step()



In [8]:
def param_norm(m):
    """Global L2 norm over all parameters of module `m`."""
    squared = [p.norm().item() ** 2 for p in m.parameters()]
    return math.sqrt(sum(squared))


def grad_norm(m):
    """Global L2 norm over the gradients of `m`'s parameters; params with no gradient are skipped."""
    squared = [p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None]
    return math.sqrt(sum(squared))

In [9]:
# Semi-supervised training. Each step:
#   1. runs one unsupervised and one supervised forward pass;
#   2. backpropagates their summed loss;
#   3. handles checkpointing, LR decay, KL-weight annealing and logging.
# Every `test_epoch` epochs the model is evaluated on the supervised part of the test set.


def _one_hot_labels(label_array, n_classes, use_cuda):
    """Return a float one-hot tensor of shape (len(label_array), n_classes).

    `label_array` goes through torch.from_numpy, so it must be a numpy array of
    integer class ids. The row count comes from the array itself, so a short
    final test batch is encoded correctly. The result is moved to the GPU when
    `use_cuda` is true, so it lives on the same device as the CUDA model.
    """
    index = torch.from_numpy(label_array)[:, None].long()
    one_hot = torch.zeros(index.size(0), n_classes).scatter_(1, index, 1)
    return one_hot.cuda() if use_cuda else one_hot


def _current_lr(opt):
    """Learning rate of the optimizer's first parameter group (scheduler.get_lr() is deprecated)."""
    return opt.param_groups[0]['lr']


# `load_epoch` is the iteration number of the resumed checkpoint, so the step
# counter continues from there. `beta` (KL weight) starts from its current value.
total_step = load_epoch
has_cuda = torch.cuda.is_available()

# torch.save does not create missing directories.
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)

# `epoch` (config) is the number of epochs. Iterate with a separate variable so
# re-running this cell does not silently shrink the epoch count.
for ep in range(epoch):
    train_loader = MolTreeFolder(train_folder, vocab, batch_size=batch_size)
    meters = np.zeros(5)  # running sums: KL, word, topo, assm, clsf
    n_accum = 0           # steps accumulated into `meters` since the last print

    model.train()

    for supervised_batch, unsupervised_batch in train_loader:
        # Train only on full batches from both streams.
        if (len(supervised_batch['labels']) != batch_size
                or len(unsupervised_batch['labels']) != batch_size):
            continue
        total_step += 1

        labels = _one_hot_labels(supervised_batch['labels'], y_size, has_cuda)

        model.zero_grad()
        unsupervised_loss, kl_div1, wacc1, tacc1, sacc1, _ = model(unsupervised_batch['data'], None, beta)
        supervised_loss, kl_div2, wacc2, tacc2, sacc2, clsf_acc = model(supervised_batch['data'], labels, beta)
        loss = unsupervised_loss + supervised_loss
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
        optimizer.step()

        # Sum the KL terms of the two passes. Average the accuracies so the
        # reported percentages are not inflated (assuming the model returns
        # accuracies as fractions, which the *100 below suggests).
        kl_div = kl_div1 + kl_div2
        wacc = (wacc1 + wacc2) / 2.0
        tacc = (tacc1 + tacc2) / 2.0
        sacc = (sacc1 + sacc2) / 2.0
        meters = meters + np.array([kl_div, wacc * 100, tacc * 100, sacc * 100, clsf_acc * 100])
        n_accum += 1

        if total_step % save_iter == 0:
            torch.save(model.state_dict(), save_dir + "/model.iter-" + str(total_step))

        if total_step % anneal_iter == 0:
            scheduler.step()
            print("learning rate: %.6f" % _current_lr(optimizer))

        if total_step % kl_anneal_iter == 0 and total_step >= warmup:
            beta = min(max_beta, beta + step_beta)

        if total_step % print_iter == 0:
            # Divide by the actual number of accumulated steps; this can be less
            # than print_iter right after an epoch boundary resets `meters`.
            meters /= n_accum
            print("Epoch: %d" % ep)
            print("[Train] Loss: %.3f, KL: %.2f, Word: %.2f, Topo: %.2f, Assm: %.2f, Clsf: %.2f" % (
                float(loss), meters[0], meters[1], meters[2], meters[3], meters[4]))
            sys.stdout.flush()
            meters *= 0
            n_accum = 0

    if ep % test_epoch == 0:
        model.eval()
        test_loader = MolTreeFolder(test_folder, vocab, batch_size=batch_size)
        test_meters = np.zeros(5)
        test_loss = 0.0
        batch_count = 0

        with torch.no_grad():
            # Only the supervised half of each test item is evaluated.
            for supervised_batch, _ in test_loader:
                # Same device and sizing as training labels (previously left on CPU
                # and sized with batch_size even for a short final batch).
                labels = _one_hot_labels(supervised_batch['labels'], y_size, has_cuda)
                loss, kl_div, wacc, tacc, sacc, clsf_acc = model(supervised_batch['data'], labels, beta)
                test_loss += float(loss)
                test_meters = test_meters + np.array([kl_div, wacc * 100, tacc * 100, sacc * 100, clsf_acc * 100])
                batch_count += 1

        if batch_count > 0:
            test_meters /= batch_count
            test_loss /= batch_count
            print("[Test] Loss: %.3f, KL: %.2f, Word: %.2f, Topo: %.2f, Assm: %.2f, Clsf: %.2f" % (
                test_loss, test_meters[0], test_meters[1], test_meters[2], test_meters[3], test_meters[4]))
        else:
            print("[Test] no test batches found in %s" % test_folder)
        sys.stdout.flush()



RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`