In [1]:
class Args: pass

In [97]:
# python
import os
import pickle
import time
import math
import random
from multiprocessing import cpu_count
from collections import OrderedDict, defaultdict

# nltk
import nltk
nltk.download('punkt')

# matplotlib
import matplotlib.pyplot as plt

# numpy
import numpy as np

# torch imports
import torch
from torch.utils.data import DataLoader

# ours
from corpus.ptb import PTB
from corpus.brown import Brown
from corpus.gutenberg import Gutenberg
from corpus.kjv import Bible
from corpus.wikitext_2 import Wikitext2
from corpus.wikitext_103 import Wikitext103
import util
from util.utils import to_var, expierment_name
from models.bowman import SentenceVAE

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


# Utilities

In [98]:
def create_datasets(args):
    # select correct corpus class
    assert args.corpus in ['ptb', 'bible', 'gutenberg', 'brown', 'wikitext-2', 'wikitext-103']
    if args.corpus == 'ptb':
        corpus_class = PTB
    elif args.corpus == 'kjv' or args.corpus == 'bible':
        nltk.download('gutenberg')
        corpus_class = Bible
    elif args.corpus == 'gutenberg':
        nltk.download('gutenberg')
        corpus_class = Gutenberg
    elif args.corpus == 'brown':
        nltk.download('brown')
        corpus_class = Brown
    elif args.corpus == 'wikitext-2':
        corpus_class = Wikitext2
    elif args.corpus == 'wikitext-103':
        corpus_class = Wikitext103
    
    # prepare for splits
    splits = [util.TRAIN, util.VAL] + ([util.TEST] if args.test else [])
    datasets = OrderedDict()
    datasets.splits = splits
    
    # create train, validation, and possibly test split
    for split in datasets.splits:
        datasets[split] = corpus_class(
            data_dir=args.data_dir,
            split=split,
            create_data=args.create_data,
            max_sequence_length=args.max_sequence_length,
            min_occ=args.min_occ,
            embeddings=args.embeddings
        )
    
    # return the splits
    return datasets

In [99]:
def create_model(args, datasets):
    model = SentenceVAE(
            vocab_size=datasets[util.TRAIN].vocab_size,
            sos_idx=datasets[util.TRAIN].sos_idx,
            eos_idx=datasets[util.TRAIN].eos_idx,
            pad_idx=datasets[util.TRAIN].pad_idx,
            unk_idx=datasets[util.TRAIN].unk_idx,
            max_sequence_length=args.max_sequence_length,
            embedding_size=args.embedding_size,
            rnn_type=args.rnn_type,
            hidden_size=args.hidden_size,
            word_dropout=args.word_dropout,
            embedding_dropout=args.embedding_dropout,
            latent_size=args.latent_size,
            num_layers=args.num_layers,
            bidirectional=args.bidirectional
        )
    if torch.cuda.is_available():
        model = model.cuda()
        
    return model

In [100]:
def kl_anneal_function(anneal_function, step, k, x0):
    if anneal_function == 'logistic':
        return float(1.0 / (1.0 + np.exp(-k * (step - x0))))
    elif anneal_function == 'linear':
        return min(1.0, step / x0)
    elif anneal_function == 'const':
        return 1.0

In [101]:
def loss_fn(logp, target, length, mean, logv, anneal_function, step, k, x0, pad_idx):
    NLL = torch.nn.NLLLoss(reduction='sum', ignore_index=pad_idx)
    
    # cut-off unnecessary padding from target, and flatten
    target = target[:, :torch.max(length).item()].contiguous().view(-1)
    logp = logp.view(-1, logp.size(2))
        
    # Negative Log Likelihood
    NLL_loss = NLL(logp, target)

    # KL Divergence
    KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
    KL_weight = kl_anneal_function(anneal_function, step, k, x0)

    return NLL_loss, KL_loss, KL_weight

In [102]:
def idx2word(sents, i2w, pad_idx):
    sent_str = [str()] * len(sents)

    for sent_idx, sent in enumerate(sents):
        for word_id in sent:
            try:
                word_id = word_id.item()
            except: pass
            
            if word_id == pad_idx:
                break
            
            sent_str[sent_idx] += (i2w[word_id] + " ")

        sent_str[sent_idx] = sent_str[sent_idx].strip()


    return sent_str

# Model/Runtime Arguments

# Prep Dataset/Model

In [103]:
def train(model, datasets, args):
    print(model)
    
    timestamp = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    # create the directory for saving this model
    args.save_model_path = os.path.join(util.MODEL_DIR, timestamp)
    os.makedirs(args.save_model_path)
    
    # create the optimizer, the tracker, and initialize the step to 0
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    trackers = {split: defaultdict(list) for split in datasets.splits}
    step = 0
    best = float('inf')
    
    # get the pad index, for convenience
    pad_idx = datasets['train'].get_w2i()['<pad>']
    
    # go!
    for epoch in range(args.epochs):
        for split in datasets.splits:
            print("SPLIT = {}".format(split))
            
            data_loader = DataLoader(
                dataset=datasets[split],
                batch_size=args.batch_size,
                shuffle=split=='train',
                num_workers=cpu_count(),
                pin_memory=torch.cuda.is_available()
            )

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):
                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])

                # loss calculation
                NLL_loss, KL_loss, KL_weight = loss_fn(logp, batch['target'],
                    batch['length'], mean, logv, args.anneal_function, step, args.k, args.x0, pad_idx)

                loss = (NLL_loss + KL_weight * KL_loss) / batch_size

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                if iteration % args.print_every == 0 or iteration+1 == len(data_loader):
                    print("%s Batch %04d/%i, Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                        % (split.upper(),
                           iteration,
                           len(data_loader) - 1,
                           loss.item(),
                           NLL_loss.item() / batch_size,
                           KL_loss.item() / batch_size,
                           KL_weight))

                trackers[split]['ELBO'].append(loss.item())
                trackers[split]['NLL'].append(NLL_loss.item() / batch_size)
                trackers[split]['KLL'].append(KL_loss.item() / batch_size)
                trackers[split]['KL_weight'].append(KL_weight)
                
#                 if split == 'valid':
#                     i2w = datasets['train'].get_i2w()
#                     trackers[split]['target_sents'] += idx2word(batch['target'].data, i2w=i2w, pad_idx=pad_idx)
#                     trackers[split]['z'].append(z.tolist())

            
            """
            END OF BATCH
            """
            print("%s Epoch %02d/%i, Mean ELBO %9.4f" % (split.upper(), epoch, args.epochs, np.mean(trackers[split]['ELBO'])))

            # save a dump of all sentences and the encoded latent space
#             if split == 'valid':
#                 dump = {'target_sents':trackers[split]['target_sents'], 'z':trackers[split]['z']}
#                 if not os.path.exists(os.path.join('dumps', ts)):
#                     os.makedirs('dumps/' + ts)
#                 with open(os.path.join('dumps/'+ts+'/valid_E%i.pickle' % epoch), 'wb') as dump_file:
#                     pickle.dump(dump, dump_file)

            # save checkpoint
            if split == 'train':                
                # save checkpoint
                checkpoint_path = os.path.join(args.save_model_path, "E%i.pytorch" % (epoch))
                torch.save(model.state_dict(), checkpoint_path)
                print("Model saved at %s" % checkpoint_path)
                
                # check if best checkpoint so far
                if np.mean(trackers[split]['ELBO']) < best:
                    best = np.mean(trackers[split]['ELBO'])
                    args.load_checkpoint = 'E{}.pytorch'.format(epoch)
                
    return trackers, model

# Plotting Utilities

In [104]:
def exponential_smoothing(ys, beta=0.8, ub=math.inf, lb=-math.inf):
    """
    This is ugly, and I should have used a comprehension, but
    it'll get the job done. I made it a function because I suspect
    I may need it later.
    """
    smooth_ys = [ys[0]]
    for y in ys:
        if y > ub or y < lb:
            smooth_ys.append(smooth_ys[-1])
        else:
            smooth_ys.append(beta * smooth_ys[-1] + (1 - beta) * y)
    return smooth_ys[1:]

In [105]:
def plot(ELBO, NLL, KL, title, fname=None, xlabel="Epochs", ylabel="Measurements", hline=None, epochs=None):
    """
    Just a *slight* abstraction over pyplot to ease development a bit.
    """
    xs = list(range(len(ELBO)))
    if epochs is not None:
        xs = [x / len(xs) * epochs for x in xs]
    
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    
    if hline:
        plt.axhline(y=hline, color='r', linestyle='-')
    
    plt.plot(xs, ELBO, label="ELBO")
    plt.plot(xs, NLL, label="NLL Loss", c='blue')
    plt.plot(xs, KL, label="KL Loss", c='red')
    plt.legend()
    
    if fname:
        plt.savefig(fname)
    else:
        plt.show()
        
    plt.clf()

In [106]:
def plot_elbo(ELBO, fname=None, title='ELBO', xlabel="Epochs", ylabel="ELBO", hline=None, epochs=None):
    """
    Just a *slight* abstraction over pyplot to ease development a bit.
    """
    xs = list(range(len(ELBO)))
    if epochs is not None:
        xs = [x / len(xs) * epochs for x in xs]
    
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    
    plt.plot(xs, ELBO, label="ELBO")
    plt.legend()
        
    if fname:
        plt.savefig(fname)
    else:
        plt.show()
        
    plt.clf()

In [107]:
def graph(trackers, datasets, args):
    for split in datasets.splits:
        fname = '{}_perf:emb{}-z{}-lstm{}-maxlen{}'.format(
            split,
            args.embedding_size,
            args.latent_size,
            args.hidden_size,
            args.max_sequence_length
        )
        
        fname = os.path.join(args.save_model_path, fname)
        
        plot(
            fname=fname,
            ELBO=exponential_smoothing(trackers[split]['ELBO']),
            KL=exponential_smoothing(trackers[split]['KLL']),
            NLL=exponential_smoothing(trackers[split]['NLL']),
            title='S-VAE *{}* Performance\n(Mikolov\'s Simplified PTB, max length={})'.format(
                split,
                args.max_sequence_length
            ),
            epochs=args.epochs
        )

# Running the Experiments

In [108]:
def interpolate(start, end, steps):
    steps = steps + 2
    
    interpolation = np.zeros((start.shape[0], steps))

    for dim, (s, e) in enumerate(zip(start, end)):
        interpolation[dim] = np.linspace(s, e, steps)

    return interpolation.T

In [109]:
def save_args(args):
    fname = os.path.join(args.save_model_path, 'args')
    with open(fname, 'w+') as file:
        lines = ['{}: {}\n'.format(key, val) for key, val in vars(args).items()]
        file.writelines(lines)

In [110]:
def save_trackers(trackers, args):
    fname = os.path.join(args.save_model_path, 'trackers.pickle')
    with open(fname, 'wb') as file:
        pickle.dump(trackers, file)

In [111]:
def test(args):
    with open(args.data_dir + '/ptb.vocab.pickle', 'rb') as file:
        vocab = pickle.load(file)

    w2i, i2w = vocab['w2i'], vocab['i2w']

    model = SentenceVAE(
        vocab_size=len(w2i),
        sos_idx=w2i['<sos>'],
        eos_idx=w2i['<eos>'],
        pad_idx=w2i['<pad>'],
        unk_idx=w2i['<unk>'],
        max_sequence_length=args.max_sequence_length,
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        embedding_dropout=args.embedding_dropout,
        latent_size=args.latent_size,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional
        )

    checkpoint_path = os.path.join(args.save_model_path, args.load_checkpoint)
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(checkpoint_path)

    model.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
    print("Model loaded from %s" % (checkpoint_path))

    if torch.cuda.is_available():
        model = model.cuda()
    
    model.eval()

    fname = os.path.join(args.save_model_path, 'samples')
    lines = []
    with open(fname, 'w+') as file:
        samples, z = model.inference(n=args.num_samples)
        lines += ['----------SAMPLES----------']
        lines += [line + '\n' for line in idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>'])]
        lines += ['\n']

        z1 = torch.randn([args.latent_size]).numpy()
        z2 = torch.randn([args.latent_size]).numpy()
        z = to_var(torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float())
        samples, _ = model.inference(z=z)
        lines += ['-------SELF-GENERATED INTERPOLATION-------']
        lines += [line + '\n' for line in idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>'])]
        lines += ['\n']

        # pick two random sentences
        i = random.randint(0, len(datasets['train']))
        j = random.randint(0, len(datasets['train']))

        s_i = torch.tensor([datasets['train'][i]['input']])
        s_j = torch.tensor([datasets['train'][j]['input']])

        with torch.no_grad():
            _, _, _, z_i = model(s_i, torch.tensor([1]))
            _, _, _, z_j = model(s_j, torch.tensor([1]))
            
        z1, z2 = z_i.squeeze().numpy(), z_j.squeeze().numpy()
        z = to_var(torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float())
        samples, _ = model.inference(z=z)
        lines += ['-------DATA-DRIVEN INTERPOLATION----------']
        lines += [line + '\n'  for line in idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>'])]
        lines += ['\n']
        
        print("wrote samples to '{}'".format(fname))
        file.writelines(lines)
        
    return lines

In [112]:
def count_sentences(datasets):
    data = datasets['train'] + datasets['valid'] + datasets['test']
    
    return len(data)

In [113]:
def count_words(datasets):
    data = datasets['train'] + datasets['valid'] + datasets['test']
    total = 0
    for sent in data:
        total = total + sent['length']
        
    return total

In [158]:
# set all model/runtime arguments

args = Args()

args.data_dir = 'data'
args.create_data = True
args.max_sequence_length = 50
args.min_occ = 1
args.test = True
args.epochs = 10
args.batch_size = 64
args.learning_rate = 0.001

args.corpus = 'bible'

args.num_samples = 10

args.embeddings = True
args.embedding_size = 300
args.rnn_type = 'gru'
args.hidden_size = 512
args.num_layers = 1
args.bidirectional = False
args.latent_size = 32
args.word_dropout = 0.0
args.embedding_dropout = 0.5

args.anneal_function = 'logistic'
args.k = 0.0025
args.x0 = 2500

args.print_every = 50
args.tensorboard_logging = False
args.logdir = 'logs'
args.save_model_path = 'bin/good25'
args.load_checkpoint = 'E9.pytorch'

args.rnn_type = args.rnn_type.lower()
args.anneal_function = args.anneal_function.lower()

assert args.rnn_type in ['rnn', 'lstm', 'gru']
assert args.anneal_function in ['logistic', 'linear', 'const']
assert 0 <= args.word_dropout <= 1
assert args.corpus in ['ptb', 'bible', 'gutenberg', 'brown', 'wikitext-2', 'wikitext-103']

In [159]:
def run_experiment(args):
    # create the datasets and model
    datasets = create_datasets(args)

    # create a new model
    model = create_model(args, datasets)
    
    # train the model and record its performance
    trackers, model = train(model, datasets, args)
    
    # write args to file
    save_args(args)
    
    # save the trackers
    save_trackers(trackers, args)
    
    # graph the results and save
    graph(trackers, datasets, args)

In [160]:
datasets = run_experiment(args)

UnboundLocalError: local variable 'corpus_class' referenced before assignment

# Testing/Generating Samples

In [None]:
# generate samples and interpolations
# test(args)