In [1]:
import argparse
import torch
import pickle 
import numpy as np 
import os 
import math 
import random 
import sys
import matplotlib.pyplot as plt 
import data
import scipy.io

from torch import nn, optim
from torch.nn import functional as F

from etm import ETM
from utils import nearest_neighbors, get_topic_coherence, get_topic_diversity

In [2]:
class Args:
    description='The Embedded Topic Model'
    dataset = 'ah20k'
    data_path = 'data/ah20k'
    emb_path = 'data/ah20k_embeddings.txt'
    save_path = './results'
    batch_size = 1000

    ### model-related arguments
    num_topics = 5
    rho_size = 300
    emb_size = 300
    t_hidden_size = 800
    theta_act = 'relu'
    train_embeddings = 1 #

    ### optimization-related arguments
    lr= 0.005
    lr_factor = 4.0
    epochs = 200
    mode = 'train'
    optimizer = 'adam'
    seed = 2019
    enc_drop = 0.0
    clip = 0.0
    nonmono = 10
    wdecay = 1.2e-6
    anneal_lr = 0
    bow_norm = 1

    ### evaluation, visualization, and logging-related arguments
    num_words = 10
    log_interval = 2
    visualize_every = 10
    eval_batch_size = 1000
    load_from = ''
    tc = True
    td = True

In [3]:
args = Args()

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(args.seed)

## get data
# 1. vocabulary
vocab, train, valid, test = data.get_data(os.path.join(args.data_path))
vocab_size = len(vocab)
args.vocab_size = vocab_size

# 1. training data
train_tokens = train['tokens']
train_counts = train['counts']
args.num_docs_train = len(train_tokens)

# 2. dev set
valid_tokens = valid['tokens']
valid_counts = valid['counts']
args.num_docs_valid = len(valid_tokens)

# 3. test data
test_tokens = test['tokens']
test_counts = test['counts']
args.num_docs_test = len(test_tokens)
test_1_tokens = test['tokens_1']
test_1_counts = test['counts_1']
args.num_docs_test_1 = len(test_1_tokens)
test_2_tokens = test['tokens_2']
test_2_counts = test['counts_2']
args.num_docs_test_2 = len(test_2_tokens)

In [6]:
embeddings = None
if not args.train_embeddings:
    emb_path = args.emb_path
    vect_path = os.path.join(args.data_path.split('/')[0], 'embeddings.pkl')   
    vectors = {}
    with open(emb_path, 'rb') as f:
        for l in f:
            line = l.decode().split()
            word = line[0]
            if word in vocab:
                vect = np.array(line[1:]).astype(np.float)
                vectors[word] = vect
    embeddings = np.zeros((vocab_size, args.emb_size))
    words_found = 0
    for i, word in enumerate(vocab):
        try: 
            embeddings[i] = vectors[word]
            words_found += 1
        except KeyError:
            embeddings[i] = np.random.normal(scale=0.6, size=(args.emb_size, ))
    embeddings = torch.from_numpy(embeddings).to(device)
    args.embeddings_dim = embeddings.size()

print('=*'*100)
print('Training an Embedded Topic Model on {} with the following settings: {}'.format(args.dataset.upper(), args))
print('=*'*100)

=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*
Training an Embedded Topic Model on AH20K with the following settings: <__main__.Args object at 0x7f55e9d0b310>
=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*


In [7]:
## define checkpoint
if not os.path.exists(args.save_path):
    os.makedirs(args.save_path)

if args.mode == 'eval':
    ckpt = args.load_from
else:
    ckpt = os.path.join(args.save_path, 
        'etm_{}_K_{}_Htheta_{}_Optim_{}_Clip_{}_ThetaAct_{}_Lr_{}_Bsz_{}_RhoSize_{}_trainEmbeddings_{}'.format(
        args.dataset, args.num_topics, args.t_hidden_size, args.optimizer, args.clip, args.theta_act, 
            args.lr, args.batch_size, args.rho_size, args.train_embeddings))

## define model and optimizer
model = ETM(args.num_topics, vocab_size, args.t_hidden_size, args.rho_size, args.emb_size, 
                args.theta_act, embeddings, args.train_embeddings, args.enc_drop).to(device)

print('model: {}'.format(model))

if args.optimizer == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wdecay)
elif args.optimizer == 'adagrad':
    optimizer = optim.Adagrad(model.parameters(), lr=args.lr, weight_decay=args.wdecay)
elif args.optimizer == 'adadelta':
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr, weight_decay=args.wdecay)
elif args.optimizer == 'rmsprop':
    optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.wdecay)
elif args.optimizer == 'asgd':
    optimizer = optim.ASGD(model.parameters(), lr=args.lr, t0=0, lambd=0., weight_decay=args.wdecay)
else:
    print('Defaulting to vanilla SGD')
    optimizer = optim.SGD(model.parameters(), lr=args.lr)

model: ETM(
  (t_drop): Dropout(p=0.0, inplace=False)
  (theta_act): ReLU()
  (rho): Linear(in_features=300, out_features=3951, bias=False)
  (alphas): Linear(in_features=300, out_features=5, bias=False)
  (q_theta): Sequential(
    (0): Linear(in_features=3951, out_features=800, bias=True)
    (1): ReLU()
    (2): Linear(in_features=800, out_features=800, bias=True)
    (3): ReLU()
  )
  (mu_q_theta): Linear(in_features=800, out_features=5, bias=True)
  (logsigma_q_theta): Linear(in_features=800, out_features=5, bias=True)
)


In [8]:
def train(epoch):
    model.train()
    acc_loss = 0
    acc_kl_theta_loss = 0
    cnt = 0
    indices = torch.randperm(args.num_docs_train)
    indices = torch.split(indices, args.batch_size)
    for idx, ind in enumerate(indices):
        optimizer.zero_grad()
        model.zero_grad()
        data_batch = data.get_batch(train_tokens, train_counts, ind, args.vocab_size, device)
        sums = data_batch.sum(1).unsqueeze(1)
        if args.bow_norm:
            normalized_data_batch = data_batch / sums
        else:
            normalized_data_batch = data_batch
        recon_loss, kld_theta = model(data_batch, normalized_data_batch)
        total_loss = recon_loss + kld_theta
        total_loss.backward()

        if args.clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        acc_loss += torch.sum(recon_loss).item()
        acc_kl_theta_loss += torch.sum(kld_theta).item()
        cnt += 1

        if idx % args.log_interval == 0 and idx > 0:
            cur_loss = round(acc_loss / cnt, 2) 
            cur_kl_theta = round(acc_kl_theta_loss / cnt, 2) 
            cur_real_loss = round(cur_loss + cur_kl_theta, 2)

            print('Epoch: {} .. batch: {}/{} .. LR: {} .. KL_theta: {} .. Rec_loss: {} .. NELBO: {}'.format(
                epoch, idx, len(indices), optimizer.param_groups[0]['lr'], cur_kl_theta, cur_loss, cur_real_loss))
    
    cur_loss = round(acc_loss / cnt, 2) 
    cur_kl_theta = round(acc_kl_theta_loss / cnt, 2) 
    cur_real_loss = round(cur_loss + cur_kl_theta, 2)
    print('*'*100)
    print('Epoch----->{} .. LR: {} .. KL_theta: {} .. Rec_loss: {} .. NELBO: {}'.format(
            epoch, optimizer.param_groups[0]['lr'], cur_kl_theta, cur_loss, cur_real_loss))
    print('*'*100)

def visualize(m, show_emb=True):
    if not os.path.exists('./results'):
        os.makedirs('./results')

    m.eval()

    queries = ['cleaner', 'refrigerate', 'tupperware', 'curry', 'baby', 'weather', 'buffet', 
                            'ninja', 'fingernail']

    ## visualize topics using monte carlo
    with torch.no_grad():
        print('#'*100)
        print('Visualize topics...')
        topics_words = []
        gammas = m.get_beta()
        for k in range(args.num_topics):
            gamma = gammas[k]
            top_words = list(gamma.cpu().numpy().argsort()[-args.num_words+1:][::-1])
            topic_words = [vocab[a] for a in top_words]
            topics_words.append(' '.join(topic_words))
            print('Topic {}: {}'.format(k, topic_words))

        if show_emb:
            ## visualize word embeddings by using V to get nearest neighbors
            print('#'*100)
            print('Visualize word embeddings by using output embedding matrix')
            try:
                embeddings = m.rho.weight  # Vocab_size x E
            except:
                embeddings = m.rho         # Vocab_size x E
            neighbors = []
            for word in queries:
                print('word: {} .. neighbors: {}'.format(
                    word, nearest_neighbors(word, embeddings, vocab)))
            print('#'*100)

def evaluate(m, source, tc=False, td=False):
    """Compute perplexity on document completion.
    """
    m.eval()
    with torch.no_grad():
        if source == 'val':
            indices = torch.split(torch.tensor(range(args.num_docs_valid)), args.eval_batch_size)
            tokens = valid_tokens
            counts = valid_counts
        else: 
            indices = torch.split(torch.tensor(range(args.num_docs_test)), args.eval_batch_size)
            tokens = test_tokens
            counts = test_counts

        ## get \beta here
        beta = m.get_beta()

        ### do dc and tc here
        acc_loss = 0
        cnt = 0
        indices_1 = torch.split(torch.tensor(range(args.num_docs_test_1)), args.eval_batch_size)
        for idx, ind in enumerate(indices_1):
            ## get theta from first half of docs
            data_batch_1 = data.get_batch(test_1_tokens, test_1_counts, ind, args.vocab_size, device)
            sums_1 = data_batch_1.sum(1).unsqueeze(1)
            if args.bow_norm:
                normalized_data_batch_1 = data_batch_1 / sums_1
            else:
                normalized_data_batch_1 = data_batch_1
            theta, _ = m.get_theta(normalized_data_batch_1)

            ## get prediction loss using second half
            data_batch_2 = data.get_batch(test_2_tokens, test_2_counts, ind, args.vocab_size, device)
            sums_2 = data_batch_2.sum(1).unsqueeze(1)
            res = torch.mm(theta, beta)
            preds = torch.log(res)
            recon_loss = -(preds * data_batch_2).sum(1)
            
            loss = recon_loss / sums_2.squeeze()
            loss = loss.mean().item()
            acc_loss += loss
            cnt += 1
        cur_loss = acc_loss / cnt
        ppl_dc = round(math.exp(cur_loss), 1)
        print('*'*100)
        print('{} Doc Completion PPL: {}'.format(source.upper(), ppl_dc))
        print('*'*100)
        if tc or td:
            beta = beta.data.cpu().numpy()
            if tc:
                print('Computing topic coherence...')
                get_topic_coherence(beta, train_tokens, vocab)
            if td:
                print('Computing topic diversity...')
                get_topic_diversity(beta, 25)
        return ppl_dc

In [None]:
if args.mode == 'train':
    ## train model on data 
    best_epoch = 0
    best_val_ppl = 1e9
    all_val_ppls = []
    print('\n')
    print('Visualizing model quality before training...')
    visualize(model)
    print('\n')
    for epoch in range(1, args.epochs):
        train(epoch)
        val_ppl = evaluate(model, 'val')
        if val_ppl < best_val_ppl:
            with open(ckpt, 'wb') as f:
                torch.save(model, f)
            best_epoch = epoch
            best_val_ppl = val_ppl
        else:
            ## check whether to anneal lr
            lr = optimizer.param_groups[0]['lr']
            if args.anneal_lr and (len(all_val_ppls) > args.nonmono and val_ppl > min(all_val_ppls[:-args.nonmono]) and lr > 1e-5):
                optimizer.param_groups[0]['lr'] /= args.lr_factor
        if epoch % args.visualize_every == 0:
            visualize(model)
        all_val_ppls.append(val_ppl)
    with open(ckpt, 'rb') as f:
        model = torch.load(f)
    model = model.to(device)
    val_ppl = evaluate(model, 'val')
else:   
    with open(ckpt, 'rb') as f:
        model = torch.load(f)
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        ## get document completion perplexities
        test_ppl = evaluate(model, 'test', tc=args.tc, td=args.td)

        ## get most used topics
        indices = torch.tensor(range(args.num_docs_train))
        indices = torch.split(indices, args.batch_size)
        thetaAvg = torch.zeros(1, args.num_topics).to(device)
        thetaWeightedAvg = torch.zeros(1, args.num_topics).to(device)
        cnt = 0
        for idx, ind in enumerate(indices):
            data_batch = data.get_batch(train_tokens, train_counts, ind, args.vocab_size, device)
            sums = data_batch.sum(1).unsqueeze(1)
            cnt += sums.sum(0).squeeze().cpu().numpy()
            if args.bow_norm:
                normalized_data_batch = data_batch / sums
            else:
                normalized_data_batch = data_batch
            theta, _ = model.get_theta(normalized_data_batch)
            thetaAvg += theta.sum(0).unsqueeze(0) / args.num_docs_train
            weighed_theta = sums * theta
            thetaWeightedAvg += weighed_theta.sum(0).unsqueeze(0)
            if idx % 100 == 0 and idx > 0:
                print('batch: {}/{}'.format(idx, len(indices)))
        thetaWeightedAvg = thetaWeightedAvg.squeeze().cpu().numpy() / cnt
        print('\nThe 10 most used topics are {}'.format(thetaWeightedAvg.argsort()[::-1][:10]))

        ## show topics
        beta = model.get_beta()
        topic_indices = list(np.random.choice(args.num_topics, 10)) # 10 random topics
        print('\n')
        for k in range(args.num_topics):#topic_indices:
            gamma = beta[k]
            top_words = list(gamma.cpu().numpy().argsort()[-args.num_words+1:][::-1])
            topic_words = [vocab[a] for a in top_words]
            print('Topic {}: {}'.format(k, topic_words))

        if args.train_embeddings:
            ## show etm embeddings 
            try:
                rho_etm = model.rho.weight.cpu()
            except:
                rho_etm = model.rho.cpu()
            queries = ['cleaner', 'refrigerate', 'tupperware', 'curry', 'baby', 'weather', 'buffet', 
                            'ninja', 'fingernail']
            print('\n')
            print('ETM embeddings...')
            for word in queries:
                print('word: {} .. etm neighbors: {}'.format(word, nearest_neighbors(word, rho_etm, vocab)))
            print('\n')



Visualizing model quality before training...
####################################################################################################
Visualize topics...
Topic 0: ['fingernail', 'slender', 'veggie', 'collector', 'weakness', 'latest', 'struggle', 'inaccurate', 'cleaner']
Topic 1: ['refrigerate', 'crusher', 'outdoors', 'custard', 'snob', 'specific', 'hint', 'nuisance', 'weather']
Topic 2: ['abit', 'junky', 'command', 'critter', 'perform', 'consume', 'thrill', 'arise', 'pull']
Topic 3: ['afraid', 'soooo', 'shiny', 'recipient', 'tupperware', 'shag', 'simmer', 'suitable', 'robust']
Topic 4: ['write', 'anymore', 'pocket', 'removeable', 'anticipate', 'curry', 'catcher', 'ikea', 'fibrox']
####################################################################################################
Visualize word embeddings by using output embedding matrix
vectors:  (3951, 300)
query:  (300,)
word: cleaner .. neighbors: ['cleaner', 'diameter', 'keeper', 'drill', 'fixture', 'fancier', 'treme

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: 2 .. batch: 2/14 .. LR: 0.005 .. KL_theta: 0.0 .. Rec_loss: 117.93 .. NELBO: 117.93
Epoch: 2 .. batch: 4/14 .. LR: 0.005 .. KL_theta: 0.0 .. Rec_loss: 116.52 .. NELBO: 116.52
Epoch: 2 .. batch: 6/14 .. LR: 0.005 .. KL_theta: 0.0 .. Rec_loss: 117.13 .. NELBO: 117.13
Epoch: 2 .. batch: 8/14 .. LR: 0.005 .. KL_theta: 0.0 .. Rec_loss: 117.31 .. NELBO: 117.31
Epoch: 2 .. batch: 10/14 .. LR: 0.005 .. KL_theta: 0.0 .. Rec_loss: 117.24 .. NELBO: 117.24
Epoch: 2 .. batch: 12/14 .. LR: 0.005 .. KL_theta: 0.0 .. Rec_loss: 117.4 .. NELBO: 117.4
****************************************************************************************************
Epoch----->2 .. LR: 0.005 .. KL_theta: 0.0 .. Rec_loss: 117.5 .. NELBO: 117.5
****************************************************************************************************
****************************************************************************************************
VAL Doc Completion PPL: 1047.3
******************************************

****************************************************************************************************
VAL Doc Completion PPL: 1004.5
****************************************************************************************************
Epoch: 10 .. batch: 2/14 .. LR: 0.005 .. KL_theta: 0.08 .. Rec_loss: 116.05 .. NELBO: 116.13
Epoch: 10 .. batch: 4/14 .. LR: 0.005 .. KL_theta: 0.07 .. Rec_loss: 115.83 .. NELBO: 115.9
Epoch: 10 .. batch: 6/14 .. LR: 0.005 .. KL_theta: 0.07 .. Rec_loss: 115.54 .. NELBO: 115.61
Epoch: 10 .. batch: 8/14 .. LR: 0.005 .. KL_theta: 0.08 .. Rec_loss: 115.5 .. NELBO: 115.58
Epoch: 10 .. batch: 10/14 .. LR: 0.005 .. KL_theta: 0.08 .. Rec_loss: 115.59 .. NELBO: 115.67
Epoch: 10 .. batch: 12/14 .. LR: 0.005 .. KL_theta: 0.08 .. Rec_loss: 115.51 .. NELBO: 115.59
****************************************************************************************************
Epoch----->10 .. LR: 0.005 .. KL_theta: 0.08 .. Rec_loss: 115.39 .. NELBO: 115.47
**************************

Epoch: 14 .. batch: 12/14 .. LR: 0.005 .. KL_theta: 0.14 .. Rec_loss: 115.28 .. NELBO: 115.42
****************************************************************************************************
Epoch----->14 .. LR: 0.005 .. KL_theta: 0.14 .. Rec_loss: 115.32 .. NELBO: 115.46
****************************************************************************************************
****************************************************************************************************
VAL Doc Completion PPL: 1004.0
****************************************************************************************************
Epoch: 15 .. batch: 2/14 .. LR: 0.005 .. KL_theta: 0.17 .. Rec_loss: 116.3 .. NELBO: 116.47
Epoch: 15 .. batch: 4/14 .. LR: 0.005 .. KL_theta: 0.17 .. Rec_loss: 114.99 .. NELBO: 115.16
Epoch: 15 .. batch: 6/14 .. LR: 0.005 .. KL_theta: 0.16 .. Rec_loss: 114.91 .. NELBO: 115.07
Epoch: 15 .. batch: 8/14 .. LR: 0.005 .. KL_theta: 0.16 .. Rec_loss: 115.11 .. NELBO: 115.27
Epoch: 15 .. batch

Epoch: 21 .. batch: 2/14 .. LR: 0.005 .. KL_theta: 0.49 .. Rec_loss: 116.0 .. NELBO: 116.49
Epoch: 21 .. batch: 4/14 .. LR: 0.005 .. KL_theta: 0.48 .. Rec_loss: 114.75 .. NELBO: 115.23
Epoch: 21 .. batch: 6/14 .. LR: 0.005 .. KL_theta: 0.48 .. Rec_loss: 114.25 .. NELBO: 114.73
Epoch: 21 .. batch: 8/14 .. LR: 0.005 .. KL_theta: 0.47 .. Rec_loss: 114.89 .. NELBO: 115.36
Epoch: 21 .. batch: 10/14 .. LR: 0.005 .. KL_theta: 0.47 .. Rec_loss: 115.04 .. NELBO: 115.51
Epoch: 21 .. batch: 12/14 .. LR: 0.005 .. KL_theta: 0.48 .. Rec_loss: 114.72 .. NELBO: 115.2
****************************************************************************************************
Epoch----->21 .. LR: 0.005 .. KL_theta: 0.49 .. Rec_loss: 114.53 .. NELBO: 115.02
****************************************************************************************************
****************************************************************************************************
VAL Doc Completion PPL: 978.1
***************************

****************************************************************************************************
VAL Doc Completion PPL: 936.8
****************************************************************************************************
Epoch: 29 .. batch: 2/14 .. LR: 0.005 .. KL_theta: 0.99 .. Rec_loss: 114.66 .. NELBO: 115.65
Epoch: 29 .. batch: 4/14 .. LR: 0.005 .. KL_theta: 0.96 .. Rec_loss: 113.96 .. NELBO: 114.92
Epoch: 29 .. batch: 6/14 .. LR: 0.005 .. KL_theta: 0.94 .. Rec_loss: 114.31 .. NELBO: 115.25
Epoch: 29 .. batch: 8/14 .. LR: 0.005 .. KL_theta: 0.94 .. Rec_loss: 113.84 .. NELBO: 114.78
Epoch: 29 .. batch: 10/14 .. LR: 0.005 .. KL_theta: 0.94 .. Rec_loss: 113.63 .. NELBO: 114.57
Epoch: 29 .. batch: 12/14 .. LR: 0.005 .. KL_theta: 0.95 .. Rec_loss: 113.45 .. NELBO: 114.4
****************************************************************************************************
Epoch----->29 .. LR: 0.005 .. KL_theta: 0.95 .. Rec_loss: 113.28 .. NELBO: 114.23
**************************

Epoch: 33 .. batch: 12/14 .. LR: 0.005 .. KL_theta: 1.12 .. Rec_loss: 113.07 .. NELBO: 114.19
****************************************************************************************************
Epoch----->33 .. LR: 0.005 .. KL_theta: 1.13 .. Rec_loss: 112.91 .. NELBO: 114.04
****************************************************************************************************
****************************************************************************************************
VAL Doc Completion PPL: 923.7
****************************************************************************************************
Epoch: 34 .. batch: 2/14 .. LR: 0.005 .. KL_theta: 1.26 .. Rec_loss: 112.8 .. NELBO: 114.06
Epoch: 34 .. batch: 4/14 .. LR: 0.005 .. KL_theta: 1.21 .. Rec_loss: 112.67 .. NELBO: 113.88
Epoch: 34 .. batch: 6/14 .. LR: 0.005 .. KL_theta: 1.17 .. Rec_loss: 113.18 .. NELBO: 114.35
Epoch: 34 .. batch: 8/14 .. LR: 0.005 .. KL_theta: 1.17 .. Rec_loss: 112.42 .. NELBO: 113.59
Epoch: 34 .. batch:

Epoch: 41 .. batch: 2/14 .. LR: 0.005 .. KL_theta: 1.36 .. Rec_loss: 112.96 .. NELBO: 114.32
Epoch: 41 .. batch: 4/14 .. LR: 0.005 .. KL_theta: 1.36 .. Rec_loss: 113.32 .. NELBO: 114.68
Epoch: 41 .. batch: 6/14 .. LR: 0.005 .. KL_theta: 1.36 .. Rec_loss: 113.06 .. NELBO: 114.42
Epoch: 41 .. batch: 8/14 .. LR: 0.005 .. KL_theta: 1.35 .. Rec_loss: 112.96 .. NELBO: 114.31
Epoch: 41 .. batch: 10/14 .. LR: 0.005 .. KL_theta: 1.35 .. Rec_loss: 112.67 .. NELBO: 114.02
Epoch: 41 .. batch: 12/14 .. LR: 0.005 .. KL_theta: 1.35 .. Rec_loss: 112.39 .. NELBO: 113.74
****************************************************************************************************
Epoch----->41 .. LR: 0.005 .. KL_theta: 1.35 .. Rec_loss: 112.32 .. NELBO: 113.67
****************************************************************************************************
****************************************************************************************************
VAL Doc Completion PPL: 909.5
*************************

****************************************************************************************************
VAL Doc Completion PPL: 900.2
****************************************************************************************************
Epoch: 49 .. batch: 2/14 .. LR: 0.005 .. KL_theta: 1.55 .. Rec_loss: 112.37 .. NELBO: 113.92
Epoch: 49 .. batch: 4/14 .. LR: 0.005 .. KL_theta: 1.53 .. Rec_loss: 110.63 .. NELBO: 112.16
Epoch: 49 .. batch: 6/14 .. LR: 0.005 .. KL_theta: 1.53 .. Rec_loss: 111.3 .. NELBO: 112.83
Epoch: 49 .. batch: 8/14 .. LR: 0.005 .. KL_theta: 1.52 .. Rec_loss: 111.48 .. NELBO: 113.0
Epoch: 49 .. batch: 10/14 .. LR: 0.005 .. KL_theta: 1.52 .. Rec_loss: 111.71 .. NELBO: 113.23
Epoch: 49 .. batch: 12/14 .. LR: 0.005 .. KL_theta: 1.53 .. Rec_loss: 112.01 .. NELBO: 113.54
****************************************************************************************************
Epoch----->49 .. LR: 0.005 .. KL_theta: 1.53 .. Rec_loss: 111.84 .. NELBO: 113.37
***************************

Epoch: 53 .. batch: 12/14 .. LR: 0.005 .. KL_theta: 1.61 .. Rec_loss: 112.04 .. NELBO: 113.65
****************************************************************************************************
Epoch----->53 .. LR: 0.005 .. KL_theta: 1.6 .. Rec_loss: 111.65 .. NELBO: 113.25
****************************************************************************************************
****************************************************************************************************
VAL Doc Completion PPL: 895.8
****************************************************************************************************
Epoch: 54 .. batch: 2/14 .. LR: 0.005 .. KL_theta: 1.63 .. Rec_loss: 113.12 .. NELBO: 114.75
Epoch: 54 .. batch: 4/14 .. LR: 0.005 .. KL_theta: 1.6 .. Rec_loss: 113.1 .. NELBO: 114.7
Epoch: 54 .. batch: 6/14 .. LR: 0.005 .. KL_theta: 1.59 .. Rec_loss: 113.29 .. NELBO: 114.88
Epoch: 54 .. batch: 8/14 .. LR: 0.005 .. KL_theta: 1.6 .. Rec_loss: 112.36 .. NELBO: 113.96
Epoch: 54 .. batch: 10/

Epoch: 61 .. batch: 2/14 .. LR: 0.005 .. KL_theta: 1.65 .. Rec_loss: 112.2 .. NELBO: 113.85
Epoch: 61 .. batch: 4/14 .. LR: 0.005 .. KL_theta: 1.66 .. Rec_loss: 111.9 .. NELBO: 113.56
Epoch: 61 .. batch: 6/14 .. LR: 0.005 .. KL_theta: 1.68 .. Rec_loss: 111.94 .. NELBO: 113.62
Epoch: 61 .. batch: 8/14 .. LR: 0.005 .. KL_theta: 1.67 .. Rec_loss: 111.9 .. NELBO: 113.57
Epoch: 61 .. batch: 10/14 .. LR: 0.005 .. KL_theta: 1.66 .. Rec_loss: 111.86 .. NELBO: 113.52
Epoch: 61 .. batch: 12/14 .. LR: 0.005 .. KL_theta: 1.66 .. Rec_loss: 111.45 .. NELBO: 113.11
****************************************************************************************************
Epoch----->61 .. LR: 0.005 .. KL_theta: 1.66 .. Rec_loss: 111.48 .. NELBO: 113.14
****************************************************************************************************
****************************************************************************************************
VAL Doc Completion PPL: 892.4
****************************

****************************************************************************************************
VAL Doc Completion PPL: 885.1
****************************************************************************************************
Epoch: 69 .. batch: 2/14 .. LR: 0.005 .. KL_theta: 1.81 .. Rec_loss: 111.87 .. NELBO: 113.68
Epoch: 69 .. batch: 4/14 .. LR: 0.005 .. KL_theta: 1.74 .. Rec_loss: 111.54 .. NELBO: 113.28
Epoch: 69 .. batch: 6/14 .. LR: 0.005 .. KL_theta: 1.69 .. Rec_loss: 111.5 .. NELBO: 113.19
Epoch: 69 .. batch: 8/14 .. LR: 0.005 .. KL_theta: 1.68 .. Rec_loss: 111.07 .. NELBO: 112.75
Epoch: 69 .. batch: 10/14 .. LR: 0.005 .. KL_theta: 1.7 .. Rec_loss: 111.14 .. NELBO: 112.84
Epoch: 69 .. batch: 12/14 .. LR: 0.005 .. KL_theta: 1.71 .. Rec_loss: 111.07 .. NELBO: 112.78
****************************************************************************************************
Epoch----->69 .. LR: 0.005 .. KL_theta: 1.72 .. Rec_loss: 111.34 .. NELBO: 113.06
***************************

In [None]:
evaluate(model, 'val', tc=args.tc, td=args.td)

In [None]:
visualize(model)