In [1]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm
import torch
import torch.nn as nn
from torch.nn import Parameter
import sys
sys.path.append('.')

import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.autograd import Variable

import math, random
import argparse
from collections import deque
import cPickle as pickle

from fast_jtnn import *
import rdkit

In [2]:
label_col = 10 + 1 #add one because first col is smile
target_unlabeled_percentage = 0
early_stop_thresh = 2

load_epoch = 0
alpha = 50 #0.1 * len_unlabelled / len_labelled

hidden_size=450
batch_size=8
latent_size=56
depthT=20
depthG=3
y_size=2

lr=1e-3
clip_norm=50.0
beta=0.0
step_beta=0.002
max_beta=1.0
warmup=40 #40000

epoch=5
test_epoch=1
anneal_rate=0.9
anneal_iter=40 #40000
kl_anneal_iter=20 #2000

save_iter=50
print_iter=5

num_workers = 4
has_cuda = torch.cuda.is_available()

save_dir = 'data/tox21/model'
train_folder = 'fast_molvae/moses-processed/train'
test_folder = 'fast_molvae/moses-processed/test'
vocab_file = 'data/tox21/vocab.txt'

In [3]:
vocab = [x.strip("\r\n ") for x in open(vocab_file)]
vocab = Vocab(vocab)

In [4]:
weight = None #torch.tensor([1., 5.])

model = SEMIJTNNVAE(vocab, hidden_size, latent_size, y_size, depthT, depthG, alpha, weight=weight)
if has_cuda: model = model.cuda()
print model



JTNNVAE(
  (jtnn): JTNNEncoder(
    (embedding): Embedding(550, 450)
    (outputNN): Sequential(
      (0): Linear(in_features=900, out_features=450, bias=True)
      (1): ReLU()
    )
    (GRU): GraphGRU(
      (W_z): Linear(in_features=900, out_features=450, bias=True)
      (W_r): Linear(in_features=450, out_features=450, bias=False)
      (U_r): Linear(in_features=450, out_features=450, bias=True)
      (W_h): Linear(in_features=900, out_features=450, bias=True)
    )
  )
  (decoder): JTNNDecoder(
    (embedding): Embedding(550, 450)
    (W_z): Linear(in_features=900, out_features=450, bias=True)
    (U_r): Linear(in_features=450, out_features=450, bias=False)
    (W_r): Linear(in_features=450, out_features=450, bias=True)
    (W_h): Linear(in_features=900, out_features=450, bias=True)
    (W): Linear(in_features=478, out_features=450, bias=True)
    (U): Linear(in_features=478, out_features=450, bias=True)
    (U_i): Linear(in_features=900, out_features=450, bias=True)
    (W_o): 

In [5]:
for param in model.parameters():
    if param.dim() == 1:
        nn.init.constant_(param, 0)
    else:
        nn.init.xavier_normal_(param)

In [6]:
if load_epoch > 0:
    model.load_state_dict(torch.load(save_dir + "/model.iter-" + str(load_epoch)))
    
print "Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,)

Model #Params: 5368K


In [7]:
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, anneal_rate)

In [8]:
param_norm = lambda m: math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()]))
grad_norm = lambda m: math.sqrt(sum([p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None]))

In [9]:
# Doing this because 20% of labelled data got lost in test
tracker = IndexTracker(train_folder, label_idx=11, label_pct=(1-target_unlabeled_percentage)/0.8)



In [None]:
from sklearn.metrics import classification_report
from utils import time_since
import traceback

total_step = load_epoch
beta = beta

train_meter_hist = []
val_meter_hist = []

def numpy_label_to_onehot_tensor(np_labels):
    labels = torch.from_numpy(np_labels)[:, None].long()
    labels = torch.zeros(batch_size, 2).scatter_(1, labels, 1)
    labels = labels.cuda() if has_cuda else labels
    return labels

start = time.time()
for epoch in xrange(epoch):
    
    meters = np.zeros(8) # loss, kl_div, wacc, tacc, sacc, clsf_acc, division_factor
    val_meter_hist.append(np.zeros(8))
    train_meter_hist.append(np.zeros(8))

    preds = np.array([])
    targets = np.array([])
    
    # Evaluation loop
    if epoch % test_epoch == 0:
        with torch.no_grad():
            model.eval()
        
            test_loader = MolTreeFolder(test_folder, vocab, batch_size=batch_size, label_idx=label_col, num_workers=num_workers)

            for (supervised_batch, _) in test_loader:  
                try:
                    if len(supervised_batch['labels']) == batch_size:
                    
                        supervised_input = supervised_batch['data']
                        labels = numpy_label_to_onehot_tensor(supervised_batch['labels'])

                        loss, clsf_loss, kl_div, wacc, tacc, sacc, clsf_acc, (pred, target) = model(supervised_input, labels, beta)

                        preds = np.append(preds, pred.cpu().detach().numpy())
                        targets = np.append(targets, target.cpu().detach().numpy())

                        meters = meters + np.array([loss, clsf_loss, kl_div, wacc * 100, tacc * 100, sacc * 100, clsf_acc * 100, 1.], dtype=np.float32)

                except Exception as e:
                    traceback.print_exc()
                    continue

            print(classification_report(targets, preds)) 
            
            if meters[-1] > 0:
                meters /= meters[-1]
                print 'time: %s' % time_since(start)
                print "[Test] Loss: %.3f, Clsf_loss: %.2f, KL: %.2f, Word: %.2f, Topo: %.2f, Assm: %.2f, Clsf: %.2f" % (loss, meters[1], meters[2], meters[3], meters[4], meters[5], meters[6])
                val_meter_hist[-1] = meters
                meters *= 0
            else:
                val_meter_hist[-1] = np.full_like(val_meter_hist[-1], np.nan, dtype=np.double)
                    
                
    train_loader = MolTreeFolder(train_folder, vocab, batch_size=batch_size, label_idx=label_col, num_workers=num_workers, index_tracker=tracker)    
    meters *= 0
    model.train()
    
    for (supervised_batch, unsupervised_batch) in train_loader: 
        try:
            if len(supervised_batch['labels']) == batch_size and len(unsupervised_batch['labels']) == batch_size:
                model.zero_grad()
                
                total_step += 1
                labels = numpy_label_to_onehot_tensor(supervised_batch['labels'])

                unsupervised_loss, _, kl_div1, wacc1, tacc1, sacc1, _, _ = model(unsupervised_batch['data'], None, beta)
                supervised_loss, clsf_loss, kl_div2, wacc2, tacc2, sacc2, clsf_acc, (pred, target) = model(supervised_batch['data'], labels, beta)

                # print(classification_report(target.cpu().detach().numpy(), pred.cpu().detach().numpy())) 
                
                loss = unsupervised_loss + supervised_loss

                kl_div = kl_div1 + kl_div2
                wacc = (wacc1 + wacc2)/2
                tacc = (tacc1 + tacc2)/2
                sacc = (sacc1 + sacc2)/2

                meters = meters + np.array([loss, clsf_loss, kl_div, wacc * 100, tacc * 100, sacc * 100, clsf_acc * 100, 1], dtype=np.float32)

                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
                optimizer.step()
                optimizer.zero_grad()

                if total_step % save_iter == 0:
                    print "Saving model to " + save_dir + "/model.iter-" + "at step: " + str(total_step)
                    torch.save(model.state_dict(), save_dir + "/model.iter-" + str(total_step))

                if total_step % anneal_iter == 0:
                    scheduler.step()
                    print "learning rate: %.6f" % scheduler.get_lr()[0]

                if total_step % kl_anneal_iter == 0 and total_step >= warmup:
                    beta = min(max_beta, beta + step_beta)

                if total_step % print_iter == 0:
                    meters /= meters[-1]
                    
                    print "Epoch: %d | Iter: %d" % (epoch, total_step)
                    print "[Train] Loss: %.3f, Clsf_loss: %.2f, KL: %.2f, Word: %.2f, Topo: %.2f, Assm: %.2f, Clsf: %.2f" % (meters[0], meters[1], meters[2], meters[3], meters[4], meters[5], meters[6])
                    sys.stdout.flush()

                    # Cache metrics for plotting
                    train_meter_hist[-1] += meters
                    meters *= 0

        except Exception as e:
            traceback.print_exc()
            continue
            
    scheduler.step()
    train_meter_hist[-1] = train_meter_hist[-1] / train_meter_hist[-1][-1]
    
    improved = True
    if len(train_meter_hist) > early_stop_thresh+1:
        improved = False
        baseline_loss = train_meter_hist[early_stop_thresh*-1-1][0]
        for i in range(early_stop_thresh):
            if baseline_loss > train_meter_hist[i*-1][0]:
                improved = True
                
    if not improved:
        print("Stopping early due to no improvement")
        print(classification_report(actuals, preds))
        break
    else:
        continue
            

  input = module(input)


tensor([[0.5717, 0.4283],
        [0.5324, 0.4676],
        [0.5813, 0.4187],
        [0.5728, 0.4272],
        [0.5867, 0.4133],
        [0.5354, 0.4646],
        [0.5627, 0.4373],
        [0.5624, 0.4376]], device='cuda:0')
tensor([0, 0, 1, 1, 0, 1, 0, 1], device='cuda:0')
tensor(2.2116, device='cuda:0')
tensor([[0.5642, 0.4358],
        [0.5113, 0.4887],
        [0.5676, 0.4324],
        [0.5376, 0.4624],
        [0.5639, 0.4361],
        [0.5226, 0.4774],
        [0.5373, 0.4627],
        [0.5461, 0.4539]], device='cuda:0')
tensor([0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor(0.6505, device='cuda:0')
tensor([[0.5622, 0.4378],
        [0.5496, 0.4504],
        [0.5518, 0.4482],
        [0.5699, 0.4301],
        [0.5424, 0.4576],
        [0.5783, 0.4217],
        [0.5468, 0.4532],
        [0.5536, 0.4464]], device='cuda:0')
tensor([0, 0, 0, 0, 0, 0, 1, 0], device='cuda:0')
tensor(1.0203, device='cuda:0')
tensor([[0.5320, 0.4680],
        [0.5511, 0.4489],
        [0.5404, 0.4596]

tensor([[0.5436, 0.4564],
        [0.5582, 0.4418],
        [0.5626, 0.4374],
        [0.5610, 0.4390],
        [0.5455, 0.4545],
        [0.5579, 0.4421],
        [0.5331, 0.4669],
        [0.5234, 0.4766]], device='cuda:0')
tensor([0, 0, 0, 0, 0, 0, 1, 0], device='cuda:0')
tensor(1.0179, device='cuda:0')
tensor([[0.5411, 0.4589],
        [0.5386, 0.4614],
        [0.5467, 0.4533],
        [0.5532, 0.4468],
        [0.5401, 0.4599],
        [0.5157, 0.4843],
        [0.5631, 0.4369],
        [0.5553, 0.4447]], device='cuda:0')
tensor([0, 0, 1, 0, 0, 1, 0, 0], device='cuda:0')
tensor(1.3905, device='cuda:0')


Traceback (most recent call last):
  File "<ipython-input-10-d3c5ca4ed0e5>", line 40, in <module>
    loss, clsf_loss, kl_div, wacc, tacc, sacc, clsf_acc, (pred, target) = model(supervised_input, labels, beta)
  File "/home/tsa87/miniconda3/envs/python2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "fast_jtnn/jtnn_vae.py", line 188, in forward
    compute_lxy(y_batch, x_batch, x_tree_vecs, x_tree_mess, x_mol_vecs, x_jtmpn_holder)
  File "fast_jtnn/jtnn_vae.py", line 137, in compute_lxy
    word_loss, topo_loss, word_acc, topo_acc = self.decoder(x_batch, z_tree_vecs)
  File "/home/tsa87/miniconda3/envs/python2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "fast_jtnn/jtnn_dec.py", line 61, in forward
    dfs(s, mol_tree.nodes[0], -1)
IndexError: list index out of range


tensor([[0.5594, 0.4406],
        [0.5610, 0.4390],
        [0.5410, 0.4590],
        [0.5753, 0.4247],
        [0.5442, 0.4558],
        [0.5914, 0.4086],
        [0.5487, 0.4513],
        [0.5369, 0.4631]], device='cuda:0')
tensor([1, 0, 0, 1, 0, 1, 0, 0], device='cuda:0')
tensor(1.8514, device='cuda:0')
tensor([[0.5690, 0.4310],
        [0.5459, 0.4541],
        [0.5603, 0.4397],
        [0.5702, 0.4298],
        [0.5720, 0.4280],
        [0.5639, 0.4361],
        [0.5450, 0.4550],
        [0.5759, 0.4241]], device='cuda:0')
tensor([0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor(0.6324, device='cuda:0')
tensor([[0.5562, 0.4438],
        [0.5549, 0.4451],
        [0.5442, 0.4558],
        [0.5420, 0.4580],
        [0.5675, 0.4325],
        [0.5524, 0.4476],
        [0.5428, 0.4572],
        [0.5644, 0.4356]], device='cuda:0')
tensor([1, 0, 0, 0, 0, 0, 0, 1], device='cuda:0')
tensor(1.4270, device='cuda:0')
tensor([[0.5315, 0.4685],
        [0.5498, 0.4502],
        [0.5658, 0.4342]

tensor([[0.5496, 0.4504],
        [0.5875, 0.4125],
        [0.5393, 0.4607],
        [0.5610, 0.4390],
        [0.5568, 0.4432],
        [0.5482, 0.4518],
        [0.5335, 0.4665],
        [0.5518, 0.4482]], device='cuda:0')
tensor([0, 0, 0, 1, 0, 0, 0, 0], device='cuda:0')
tensor(1.0345, device='cuda:0')
tensor([[0.5483, 0.4517],
        [0.5567, 0.4433],
        [0.5598, 0.4402],
        [0.5681, 0.4319],
        [0.5743, 0.4257],
        [0.5871, 0.4129],
        [0.5521, 0.4479],
        [0.5502, 0.4498]], device='cuda:0')
tensor([0, 0, 0, 0, 0, 1, 1, 0], device='cuda:0')
tensor(1.4332, device='cuda:0')
tensor([[0.5538, 0.4462],
        [0.5303, 0.4697],
        [0.5512, 0.4488],
        [0.5768, 0.4232],
        [0.5184, 0.4816],
        [0.5637, 0.4363],
        [0.5401, 0.4599],
        [0.5763, 0.4237]], device='cuda:0')
tensor([0, 1, 0, 0, 0, 0, 0, 0], device='cuda:0')
tensor(1.0129, device='cuda:0')
tensor([[0.5245, 0.4755],
        [0.5717, 0.4283],
        [0.5551, 0.4449]

tensor([[0.5644, 0.4356],
        [0.5751, 0.4249],
        [0.5325, 0.4675],
        [0.5474, 0.4526],
        [0.5533, 0.4467],
        [0.5525, 0.4475],
        [0.5351, 0.4649],
        [0.5798, 0.4202]], device='cuda:0')
tensor([0, 1, 0, 0, 1, 0, 1, 0], device='cuda:0')
tensor(1.8045, device='cuda:0')
tensor([[0.5308, 0.4692],
        [0.5347, 0.4653],
        [0.5440, 0.4560],
        [0.5591, 0.4409],
        [0.5756, 0.4244],
        [0.5610, 0.4390],
        [0.5840, 0.4160],
        [0.5697, 0.4303]], device='cuda:0')
tensor([0, 0, 0, 1, 1, 0, 0, 0], device='cuda:0')
tensor(1.4341, device='cuda:0')
tensor([[0.5564, 0.4436],
        [0.5143, 0.4857],
        [0.5416, 0.4584],
        [0.5697, 0.4303],
        [0.5385, 0.4615],
        [0.5691, 0.4309],
        [0.5909, 0.4091],
        [0.5567, 0.4433]], device='cuda:0')
tensor([0, 0, 0, 0, 1, 0, 0, 0], device='cuda:0')
tensor(1.0161, device='cuda:0')
tensor([[0.5472, 0.4528],
        [0.5335, 0.4665],
        [0.5790, 0.4210]

tensor([[0.5556, 0.4444],
        [0.5492, 0.4508],
        [0.5161, 0.4839],
        [0.5638, 0.4362],
        [0.5496, 0.4504],
        [0.5479, 0.4521],
        [0.5673, 0.4327],
        [0.5415, 0.4585]], device='cuda:0')
tensor([0, 0, 1, 0, 0, 1, 0, 0], device='cuda:0')
tensor(1.3873, device='cuda:0')
tensor([[0.5448, 0.4552],
        [0.5782, 0.4218],
        [0.5527, 0.4473],
        [0.5506, 0.4494],
        [0.5242, 0.4758],
        [0.5469, 0.4531],
        [0.5462, 0.4538],
        [0.5730, 0.4270]], device='cuda:0')
tensor([0, 0, 1, 0, 0, 0, 0, 0], device='cuda:0')
tensor(1.0293, device='cuda:0')
tensor([[0.5585, 0.4415],
        [0.5177, 0.4823],
        [0.5705, 0.4295],
        [0.5473, 0.4527],
        [0.5632, 0.4368],
        [0.5390, 0.4610],
        [0.5460, 0.4540],
        [0.5295, 0.4705]], device='cuda:0')
tensor([0, 0, 0, 0, 1, 0, 0, 0], device='cuda:0')
tensor(1.0429, device='cuda:0')


Traceback (most recent call last):
  File "<ipython-input-10-d3c5ca4ed0e5>", line 40, in <module>
    loss, clsf_loss, kl_div, wacc, tacc, sacc, clsf_acc, (pred, target) = model(supervised_input, labels, beta)
  File "/home/tsa87/miniconda3/envs/python2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "fast_jtnn/jtnn_vae.py", line 188, in forward
    compute_lxy(y_batch, x_batch, x_tree_vecs, x_tree_mess, x_mol_vecs, x_jtmpn_holder)
  File "fast_jtnn/jtnn_vae.py", line 137, in compute_lxy
    word_loss, topo_loss, word_acc, topo_acc = self.decoder(x_batch, z_tree_vecs)
  File "/home/tsa87/miniconda3/envs/python2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "fast_jtnn/jtnn_dec.py", line 108, in forward
    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1,MAX_NB,self.hidden_size)
RuntimeError: shape '[-1, 15, 450]' is inval

tensor([[0.5463, 0.4537],
        [0.5540, 0.4460],
        [0.5634, 0.4366],
        [0.5469, 0.4531],
        [0.5612, 0.4388],
        [0.5757, 0.4243],
        [0.5514, 0.4486],
        [0.5527, 0.4473]], device='cuda:0')
tensor([0, 0, 0, 0, 1, 1, 0, 0], device='cuda:0')
tensor(1.4365, device='cuda:0')
tensor([[0.5401, 0.4599],
        [0.5322, 0.4678],
        [0.5299, 0.4701],
        [0.5572, 0.4428],
        [0.5498, 0.4502],
        [0.5666, 0.4334],
        [0.5629, 0.4371],
        [0.5525, 0.4475]], device='cuda:0')
tensor([0, 0, 0, 0, 0, 0, 1, 1], device='cuda:0')
tensor(1.4270, device='cuda:0')
tensor([[0.5825, 0.4175],
        [0.5697, 0.4303],
        [0.5556, 0.4444],
        [0.5101, 0.4899],
        [0.5458, 0.4542],
        [0.5290, 0.4710],
        [0.5324, 0.4676],
        [0.5626, 0.4374]], device='cuda:0')
tensor([0, 0, 1, 0, 0, 0, 0, 1], device='cuda:0')
tensor(1.4296, device='cuda:0')
tensor([[0.5425, 0.4575],
        [0.5435, 0.4565],
        [0.5819, 0.4181]

              precision    recall  f1-score   support

         0.0       0.85      0.99      0.91       975
         1.0       0.00      0.00      0.00       177

   micro avg       0.84      0.84      0.84      1152
   macro avg       0.42      0.50      0.46      1152
weighted avg       0.72      0.84      0.77      1152

[Test] Loss: 108.501, Clsf_loss: 56.20, KL: 3.16, Word: 0.10, Topo: 48.25, Assm: 31.89, Clsf: 84.11


In [None]:
# import matplotlib
# import matplotlib.pyplot as plt
# import numpy as np

# # Data for plotting
# t = np.arange(len(train_meter_hist))

# plt.plot(t, np.array(train_meter_hist)[:, 0], label='train loss')
# plt.plot(t, np.array(val_meter_hist)[:, 0], label='validation loss')

# plt.title('Loss vs Epoch')
# plt.legend()
# plt.show()