In [1]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm
import torch
import torch.nn as nn
from torch.nn import Parameter
import sys
sys.path.append('.')

from sklearn.metrics import classification_report
from utils import time_since
import traceback

import time

import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.autograd import Variable

import math, random
import argparse
from collections import deque
import cPickle as pickle

from fast_jtnn import *
import rdkit

In [2]:
label_col = 10 + 1 #add one because first col is smile
target_unlabeled_percentage = 0.95
early_stop_thresh = 2

load_epoch = 0
alpha = 100 #0.1 * len_unlabelled / len_labelled

hidden_size=56 #450
batch_size=8
latent_size=28
depthT=20
depthG=3
y_size=2

lr=1e-3
clip_norm=50.0
beta=0.0
step_beta=0.002
max_beta=1.0
warmup=40 #40000

epoch=5
test_epoch=1
anneal_rate=0.9
anneal_iter=40 #40000
kl_anneal_iter=20 #2000

save_iter=50
print_iter=5

num_workers = 4
has_cuda = torch.cuda.is_available()

save_dir = 'data/tox21/model'
train_folder = 'fast_molvae/moses-processed/quick_train'
test_folder = 'fast_molvae/moses-processed/quick_test'
vocab_file = 'data/tox21/vocab.txt'

In [3]:
vocab = [x.strip("\r\n ") for x in open(vocab_file)]
vocab = Vocab(vocab)

In [4]:
weight = None #torch.tensor([1., 5.])

model = SemiJTNNVAEClassifier(vocab, hidden_size, latent_size, y_size, depthT, depthG, alpha, weight=weight)
if has_cuda: model = model.cuda()
print model



SemiJTNNVAEClassifier(
  (jtnn): JTNNEncoder(
    (embedding): Embedding(550, 56)
    (outputNN): Sequential(
      (0): Linear(in_features=112, out_features=56, bias=True)
      (1): ReLU()
    )
    (GRU): GraphGRU(
      (W_z): Linear(in_features=112, out_features=56, bias=True)
      (W_r): Linear(in_features=56, out_features=56, bias=False)
      (U_r): Linear(in_features=56, out_features=56, bias=True)
      (W_h): Linear(in_features=112, out_features=56, bias=True)
    )
  )
  (decoder): JTNNDecoder(
    (embedding): Embedding(550, 56)
    (W_z): Linear(in_features=112, out_features=56, bias=True)
    (U_r): Linear(in_features=56, out_features=56, bias=False)
    (W_r): Linear(in_features=56, out_features=56, bias=True)
    (W_h): Linear(in_features=112, out_features=56, bias=True)
    (W): Linear(in_features=70, out_features=56, bias=True)
    (U): Linear(in_features=70, out_features=56, bias=True)
    (U_i): Linear(in_features=112, out_features=56, bias=True)
    (W_o): Linear

In [5]:
for param in model.parameters():
    if param.dim() == 1:
        nn.init.constant_(param, 0)
    else:
        nn.init.xavier_normal_(param)

In [6]:
if load_epoch > 0:
    model.load_state_dict(torch.load(save_dir + "/model.iter-" + str(load_epoch)))
    
print "Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,)

Model #Params: 183K


In [7]:
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, anneal_rate)

In [8]:
param_norm = lambda m: math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()]))
grad_norm = lambda m: math.sqrt(sum([p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None]))

In [9]:
# Doing this because 20% of labelled data got lost in test
tracker = IndexTracker(train_folder, label_idx=11, label_pct=(1-target_unlabeled_percentage)/0.8)

In [None]:
total_step = load_epoch
beta = beta

train_meter_hist = []
val_meter_hist = []

def numpy_label_to_onehot_tensor(np_labels):    
    labels = torch.from_numpy(np_labels)[:, None].long()
    labels = torch.zeros(batch_size, 2).scatter_(1, labels, 1)
    labels = labels.cuda() if has_cuda else labels
    return labels

start = time.time()
for epoch in xrange(epoch):
    
    meters = np.zeros(8) # loss, kl_div, wacc, tacc, sacc, clsf_acc, division_factor
    val_meter_hist.append(np.zeros(8))
    train_meter_hist.append(np.zeros(8))

    preds = np.array([])
    targets = np.array([])
    
    # Evaluation loop
    if epoch % test_epoch == 0 and epoch > 0:
        with torch.no_grad():
            model.eval()
        
            test_loader = MolTreeFolder(test_folder, vocab, batch_size=batch_size, label_idx=label_col, num_workers=num_workers)

            for (supervised_batch, unsupervised_batch) in test_loader:  
                try:
                    if len(supervised_batch['labels']) == batch_size:
                    
                        supervised_input = supervised_batch['data']
                        labels = numpy_label_to_onehot_tensor(supervised_batch['labels'])
                        
                        loss, clsf_loss, kl_div, wacc, tacc, sacc, clsf_acc, (pred, target) = model(supervised_input, labels, beta)
                        
                        preds = np.append(preds, pred.cpu().detach().numpy())
                        targets = np.append(targets, target.cpu().detach().numpy())

                        meters = meters + np.array([loss, clsf_loss, kl_div, wacc * 100, tacc * 100, sacc * 100, clsf_acc * 100, 1.], dtype=np.float32)

                except Exception as e:
                    traceback.print_exc()
                    continue

            print(classification_report(targets, preds)) 
            
            if meters[-1] > 0:
                meters /= meters[-1]
                print 'time: %s' % time_since(start)
                print "[Test] Loss: %.3f, Clsf_loss: %.2f, KL: %.2f, Word: %.2f, Topo: %.2f, Assm: %.2f, Clsf: %.2f" % (loss, meters[1], meters[2], meters[3], meters[4], meters[5], meters[6])
                val_meter_hist[-1] = meters
                meters *= 0
            else:
                val_meter_hist[-1] = np.full_like(val_meter_hist[-1], np.nan, dtype=np.double)
                    
                
    train_loader = MolTreeFolder(train_folder, vocab, batch_size=batch_size, label_idx=label_col, num_workers=num_workers, index_tracker=tracker)    
    meters *= 0
    model.train()
    
    for (supervised_batch, unsupervised_batch) in train_loader: 
        try:
            if len(supervised_batch['labels']) == batch_size and len(unsupervised_batch['labels']) == batch_size:
                model.zero_grad()
                
                total_step += 1
                labels = numpy_label_to_onehot_tensor(supervised_batch['labels'])
                
                unsupervised_loss, _, kl_div1, wacc1, tacc1, sacc1, _, _ = model(unsupervised_batch['data'], None, beta)
                supervised_loss, clsf_loss, kl_div2, wacc2, tacc2, sacc2, clsf_acc, (pred, target) = model(supervised_batch['data'], labels, beta)

                # print(classification_report(target.cpu().detach().numpy(), pred.cpu().detach().numpy())) 
                
                loss = unsupervised_loss + supervised_loss

                kl_div = kl_div1 + kl_div2
                wacc = (wacc1 + wacc2)/2
                tacc = (tacc1 + tacc2)/2
                sacc = (sacc1 + sacc2)/2

                meters = meters + np.array([loss, clsf_loss, kl_div, wacc * 100, tacc * 100, sacc * 100, clsf_acc * 100, 1], dtype=np.float32)

                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
                optimizer.step()
                optimizer.zero_grad()

                if total_step % save_iter == 0:
                    print "Saving model to " + save_dir + "/model.iter-" + "at step: " + str(total_step)
                    torch.save(model.state_dict(), save_dir + "/model.iter-" + str(total_step))

                if total_step % anneal_iter == 0:
                    scheduler.step()
                    print "learning rate: %.6f" % scheduler.get_lr()[0]

                if total_step % kl_anneal_iter == 0 and total_step >= warmup:
                    beta = min(max_beta, beta + step_beta)

                if total_step % print_iter == 0:
                    meters /= meters[-1]
                    
                    print "Epoch: %d | Iter: %d" % (epoch, total_step)
                    print "[Train] Loss: %.3f, Clsf_loss: %.2f, KL: %.2f, Word: %.2f, Topo: %.2f, Assm: %.2f, Clsf: %.2f" % (meters[0], meters[1], meters[2], meters[3], meters[4], meters[5], meters[6])
                    sys.stdout.flush()

                    # Cache metrics for plotting
                    train_meter_hist[-1] += meters
                    meters *= 0

        except Exception as e:
            traceback.print_exc()
            continue
            
    scheduler.step()
    train_meter_hist[-1] = train_meter_hist[-1] / train_meter_hist[-1][-1]
    
    improved = True
    if len(train_meter_hist) > early_stop_thresh+1:
        improved = False
        baseline_loss = train_meter_hist[early_stop_thresh*-1-1][0]
        for i in range(early_stop_thresh):
            if baseline_loss > train_meter_hist[i*-1][0]:
                improved = True
                
    if not improved:
        print("Stopping early due to no improvement")
        print(classification_report(actuals, preds))
        break
    else:
        continue
            

  input = module(input)
Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 72, in <module>
    unsupervised_loss, _, kl_div1, wacc1, tacc1, sacc1, _, _ = model(unsupervised_batch['data'], None, beta)
  File "/home/tsa87/miniconda3/envs/python2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "fast_jtnn/jtnn_vae.py", line 341, in forward
    self._compute_partial_loss(y_batch, x_batch, x_tree_vecs, x_tree_mess, x_mol_vecs, x_jtmpn_holder, beta)
  File "fast_jtnn/jtnn_vae.py", line 276, in _compute_partial_loss
    word_loss, topo_loss, word_acc, topo_acc = self.decoder(x_batch, z_tree_vecs)
  File "/home/tsa87/miniconda3/envs/python2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "fast_jtnn/jtnn_dec.py", line 61, in forward
    dfs(s, mol_tree.nodes[0], -1)
IndexError: list index out of range


Epoch: 0 | Iter: 5
[Train] Loss: 229.737, Clsf_loss: 65.68, KL: 5.02, Word: 0.09, Topo: 52.33, Assm: 50.98, Clsf: 68.75
Epoch: 0 | Iter: 10
[Train] Loss: 241.237, Clsf_loss: 57.53, KL: 6.68, Word: 0.32, Topo: 52.96, Assm: 43.35, Clsf: 80.00
Epoch: 0 | Iter: 15
[Train] Loss: 237.706, Clsf_loss: 51.88, KL: 10.26, Word: 5.42, Topo: 54.29, Assm: 58.52, Clsf: 82.50
Epoch: 0 | Iter: 20
[Train] Loss: 236.116, Clsf_loss: 50.05, KL: 14.76, Word: 10.77, Topo: 59.28, Assm: 60.20, Clsf: 82.50
Epoch: 0 | Iter: 25
[Train] Loss: 232.209, Clsf_loss: 47.34, KL: 18.24, Word: 12.72, Topo: 60.94, Assm: 56.67, Clsf: 85.00
Epoch: 0 | Iter: 30
[Train] Loss: 231.531, Clsf_loss: 45.76, KL: 19.29, Word: 15.05, Topo: 62.27, Assm: 59.62, Clsf: 87.50
Epoch: 0 | Iter: 35
[Train] Loss: 200.054, Clsf_loss: 50.68, KL: 19.40, Word: 16.04, Topo: 67.93, Assm: 66.59, Clsf: 82.50
learning rate: 0.000900
Epoch: 0 | Iter: 40
[Train] Loss: 178.385, Clsf_loss: 51.89, KL: 20.33, Word: 18.70, Topo: 68.53, Assm: 69.73, Clsf: 80.0

Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 72, in <module>
    unsupervised_loss, _, kl_div1, wacc1, tacc1, sacc1, _, _ = model(unsupervised_batch['data'], None, beta)
  File "/home/tsa87/miniconda3/envs/python2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "fast_jtnn/jtnn_vae.py", line 341, in forward
    self._compute_partial_loss(y_batch, x_batch, x_tree_vecs, x_tree_mess, x_mol_vecs, x_jtmpn_holder, beta)
  File "fast_jtnn/jtnn_vae.py", line 276, in _compute_partial_loss
    word_loss, topo_loss, word_acc, topo_acc = self.decoder(x_batch, z_tree_vecs)
  File "/home/tsa87/miniconda3/envs/python2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "fast_jtnn/jtnn_dec.py", line 108, in forward
    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1,MAX_NB,self.hidden_size)
RuntimeError: sh

Saving model to data/tox21/model/model.iter-at step: 50


Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 93, in <module>
    torch.save(model.state_dict(), save_dir + "/model.iter-" + str(total_step))
  File "/home/tsa87/miniconda3/envs/python2/lib/python2.7/site-packages/torch/serialization.py", line 260, in save
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
  File "/home/tsa87/miniconda3/envs/python2/lib/python2.7/site-packages/torch/serialization.py", line 183, in _with_file_like
    f = open(f, mode)
IOError: [Errno 2] No such file or directory: 'data/tox21/model/model.iter-50'


Epoch: 0 | Iter: 55
[Train] Loss: 155.421, Clsf_loss: 47.02, KL: 27.81, Word: 27.28, Topo: 72.18, Assm: 67.38, Clsf: 84.72
Epoch: 0 | Iter: 60
[Train] Loss: 145.795, Clsf_loss: 44.82, KL: 42.24, Word: 30.25, Topo: 74.33, Assm: 68.83, Clsf: 90.00
Epoch: 0 | Iter: 65
[Train] Loss: 136.747, Clsf_loss: 45.96, KL: 50.12, Word: 28.89, Topo: 75.78, Assm: 76.48, Clsf: 87.50
Epoch: 0 | Iter: 70
[Train] Loss: 141.088, Clsf_loss: 45.49, KL: 69.49, Word: 29.35, Topo: 77.89, Assm: 76.26, Clsf: 87.50


Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 72, in <module>
    unsupervised_loss, _, kl_div1, wacc1, tacc1, sacc1, _, _ = model(unsupervised_batch['data'], None, beta)
  File "/home/tsa87/miniconda3/envs/python2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "fast_jtnn/jtnn_vae.py", line 341, in forward
    self._compute_partial_loss(y_batch, x_batch, x_tree_vecs, x_tree_mess, x_mol_vecs, x_jtmpn_holder, beta)
  File "fast_jtnn/jtnn_vae.py", line 276, in _compute_partial_loss
    word_loss, topo_loss, word_acc, topo_acc = self.decoder(x_batch, z_tree_vecs)
  File "/home/tsa87/miniconda3/envs/python2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "fast_jtnn/jtnn_dec.py", line 108, in forward
    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(-1,MAX_NB,self.hidden_size)
RuntimeError: sh

6


Traceback (most recent call last):
  File "fast_jtnn/datautils.py", line 217, in __getitem__
    'data': tensorize(self.data[idx], self.vocab, assm=self.assm),
  File "fast_jtnn/datautils.py", line 244, in tensorize
    jtmpn_holder = JTMPN.tensorize(cands, mess_dict)
  File "fast_jtnn/jtmpn.py", line 122, in tensorize
    fatoms = torch.stack(fatoms, 0)
RuntimeError: stack expects a non-empty TensorList


learning rate: 0.000810
Epoch: 0 | Iter: 80
[Train] Loss: 137.836, Clsf_loss: 42.47, KL: 84.85, Word: 30.15, Topo: 79.03, Assm: 71.70, Clsf: 91.67
Epoch: 0 | Iter: 85
[Train] Loss: 146.608, Clsf_loss: 50.34, KL: 101.43, Word: 32.64, Topo: 80.90, Assm: 70.02, Clsf: 82.50


Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 66, in <module>
    if len(supervised_batch['labels']) == batch_size and len(unsupervised_batch['labels']) == batch_size:
TypeError: 'NoneType' object has no attribute '__getitem__'


Epoch: 0 | Iter: 90
[Train] Loss: 155.279, Clsf_loss: 50.50, KL: 107.74, Word: 29.05, Topo: 80.20, Assm: 68.63, Clsf: 85.00


Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 66, in <module>
    if len(supervised_batch['labels']) == batch_size and len(unsupervised_batch['labels']) == batch_size:
TypeError: 'NoneType' object has no attribute '__getitem__'


Epoch: 0 | Iter: 95
[Train] Loss: 139.767, Clsf_loss: 49.38, KL: 103.37, Word: 29.93, Topo: 82.18, Assm: 70.50, Clsf: 85.00


Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 66, in <module>
    if len(supervised_batch['labels']) == batch_size and len(unsupervised_batch['labels']) == batch_size:
TypeError: 'NoneType' object has no attribute '__getitem__'


Saving model to data/tox21/model/model.iter-at step: 100


Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 93, in <module>
    torch.save(model.state_dict(), save_dir + "/model.iter-" + str(total_step))
  File "/home/tsa87/miniconda3/envs/python2/lib/python2.7/site-packages/torch/serialization.py", line 260, in save
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
  File "/home/tsa87/miniconda3/envs/python2/lib/python2.7/site-packages/torch/serialization.py", line 183, in _with_file_like
    f = open(f, mode)
IOError: [Errno 2] No such file or directory: 'data/tox21/model/model.iter-100'
Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 66, in <module>
    if len(supervised_batch['labels']) == batch_size and len(unsupervised_batch['labels']) == batch_size:
TypeError: 'NoneType' object has no attribute '__getitem__'


Epoch: 0 | Iter: 105
[Train] Loss: 137.298, Clsf_loss: 47.73, KL: 108.92, Word: 29.75, Topo: 80.98, Assm: 68.29, Clsf: 88.75


Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 66, in <module>
    if len(supervised_batch['labels']) == batch_size and len(unsupervised_batch['labels']) == batch_size:
TypeError: 'NoneType' object has no attribute '__getitem__'


Epoch: 0 | Iter: 110
[Train] Loss: 125.618, Clsf_loss: 44.40, KL: 104.92, Word: 31.83, Topo: 84.88, Assm: 76.60, Clsf: 90.00
Epoch: 0 | Iter: 115
[Train] Loss: 126.708, Clsf_loss: 42.98, KL: 114.07, Word: 31.72, Topo: 84.01, Assm: 68.92, Clsf: 92.50


Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 66, in <module>
    if len(supervised_batch['labels']) == batch_size and len(unsupervised_batch['labels']) == batch_size:
TypeError: 'NoneType' object has no attribute '__getitem__'


learning rate: 0.000729
Epoch: 0 | Iter: 120
[Train] Loss: 133.173, Clsf_loss: 45.43, KL: 122.60, Word: 31.42, Topo: 83.55, Assm: 74.82, Clsf: 87.50


Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 66, in <module>
    if len(supervised_batch['labels']) == batch_size and len(unsupervised_batch['labels']) == batch_size:
TypeError: 'NoneType' object has no attribute '__getitem__'


Epoch: 0 | Iter: 125
[Train] Loss: 122.145, Clsf_loss: 43.19, KL: 123.60, Word: 32.55, Topo: 83.31, Assm: 81.00, Clsf: 90.00


Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 66, in <module>
    if len(supervised_batch['labels']) == batch_size and len(unsupervised_batch['labels']) == batch_size:
TypeError: 'NoneType' object has no attribute '__getitem__'


Epoch: 0 | Iter: 130
[Train] Loss: 135.573, Clsf_loss: 42.90, KL: 141.11, Word: 29.22, Topo: 85.06, Assm: 71.51, Clsf: 92.50


Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 66, in <module>
    if len(supervised_batch['labels']) == batch_size and len(unsupervised_batch['labels']) == batch_size:
TypeError: 'NoneType' object has no attribute '__getitem__'


Epoch: 0 | Iter: 135
[Train] Loss: 129.896, Clsf_loss: 43.51, KL: 130.14, Word: 31.63, Topo: 84.82, Assm: 76.57, Clsf: 92.50


Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 66, in <module>
    if len(supervised_batch['labels']) == batch_size and len(unsupervised_batch['labels']) == batch_size:
TypeError: 'NoneType' object has no attribute '__getitem__'


Epoch: 0 | Iter: 140
[Train] Loss: 122.580, Clsf_loss: 39.50, KL: 125.90, Word: 31.87, Topo: 85.73, Assm: 77.69, Clsf: 95.00
Epoch: 0 | Iter: 145
[Train] Loss: 120.767, Clsf_loss: 37.93, KL: 123.42, Word: 30.26, Topo: 85.11, Assm: 71.25, Clsf: 97.50


Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 66, in <module>
    if len(supervised_batch['labels']) == batch_size and len(unsupervised_batch['labels']) == batch_size:
TypeError: 'NoneType' object has no attribute '__getitem__'


Saving model to data/tox21/model/model.iter-at step: 150


Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 93, in <module>
    torch.save(model.state_dict(), save_dir + "/model.iter-" + str(total_step))
  File "/home/tsa87/miniconda3/envs/python2/lib/python2.7/site-packages/torch/serialization.py", line 260, in save
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
  File "/home/tsa87/miniconda3/envs/python2/lib/python2.7/site-packages/torch/serialization.py", line 183, in _with_file_like
    f = open(f, mode)
IOError: [Errno 2] No such file or directory: 'data/tox21/model/model.iter-150'
Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 66, in <module>
    if len(supervised_batch['labels']) == batch_size and len(unsupervised_batch['labels']) == batch_size:
TypeError: 'NoneType' object has no attribute '__getitem__'


Epoch: 0 | Iter: 155
[Train] Loss: 125.857, Clsf_loss: 39.65, KL: 125.54, Word: 31.35, Topo: 84.67, Assm: 73.01, Clsf: 95.00


Traceback (most recent call last):
  File "<ipython-input-10-997b423b499f>", line 66, in <module>
    if len(supervised_batch['labels']) == batch_size and len(unsupervised_batch['labels']) == batch_size:
TypeError: 'NoneType' object has no attribute '__getitem__'


In [None]:
# import matplotlib
# import matplotlib.pyplot as plt
# import numpy as np

# # Data for plotting
# t = np.arange(len(train_meter_hist))

# plt.plot(t, np.array(train_meter_hist)[:, 0], label='train loss')
# plt.plot(t, np.array(val_meter_hist)[:, 0], label='validation loss')

# plt.title('Loss vs Epoch')
# plt.legend()
# plt.show()