## Deep Markov model

In [1]:
import os
import re
import pickle
import numpy as np
from random import shuffle

from collections import Counter
from sklearn.model_selection import train_test_split
from time import time

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist

from pyro import poutine
from pyro import optim
from pyro import infer

In [2]:
DATA = 'simple_english_wikipedia/'
CORPUS = 'corpus.txt'
KEEP_WORDS = 100

def load_sentences_from_raw(keep_words=KEEP_WORDS):
    # load the text, remove extra characters, and split into sentences
    with open(os.path.join(DATA, CORPUS)) as fh:
        # load and make lower
        text = fh.read().lower()

    # remove non-word, space, or period characters
    text = re.sub(r'[^\s\dA-Za-z.]', '', text)
    # get rid of headings
    text = re.sub(r'\n\S+\n', ' ', '\n' + text)
    # change all whitespace to single space
    text = re.sub(r'\s+', ' ', text).strip()
    # replace numbers with "NUM"
    text = re.sub(r'\d+', 'NUM', text)
    # compile the most common words
    # split on periods
    sentences = text.split('. ')
    # identify the most common words
    counter = Counter(text.replace('.', '').split())
    most_common = {w[0]: torch.tensor(i+1) \
                   for i, w in enumerate(counter.most_common(keep_words))}
    with open(os.path.join(DATA, 'hash_dict.pkl'), 'wb') as fh:
        pickle.dump(most_common, fh)
    # create hashed versions of the sentences
    def hash_sentence(sentence):
        iterator =  (most_common.get(w) for w in sentence.split())
        # encode uncommon words as zeros
        return torch.stack([x if x is not None else torch.tensor(0) \
                            for x in iterator]).to(torch.int64)
    hashed_sentences = [hash_sentence(sentence) for sentence in sentences if sentence]
    shuffle(hashed_sentences)
    return hashed_sentences

def load_sentences(group, use_prehash=True, keep_words=100, n_groups=50):
    if use_prehash == False or not os.path.isfile(os.path.join(DATA, f'prehash{group}.pkl')):
        hashed_sentences = load_sentences_from_raw(keep_words)
        total_size = len(hashed_sentences)
        step = total_size // (n_groups - (0 if total_size % n_groups == 0 else 1))
        for g, i, in enumerate(range(0, total_size, step)):
            with open(os.path.join(DATA, f'prehash{g}.pkl'), 'wb') as fh:
                pickle.dump(hashed_sentences[i:(i+step)], fh)
        return hashed_sentences[step*group:step*(group+1)]
    else:
        with open(os.path.join(DATA, f'prehash{group}.pkl'), 'rb') as fh:
            hashed_sentences = pickle.load(fh)
    return hashed_sentences

def get_group(group):
    hashed_sentences = load_sentences(group)
    sentence_lengths = torch.from_numpy(np.array([len(sentence) for sentence in hashed_sentences]))
    # change hashed_sentences to padded tensors
    hashed_sentences = nn.utils.rnn.pad_sequence(hashed_sentences, batch_first=True)
    # make one-hot so it's compatible with rnn
    hashed_sentences = nn.functional.one_hot(hashed_sentences)
    return hashed_sentences, sentence_lengths

In [3]:
def make_mask(batch, seq_lengths):
    mask = torch.zeros(batch.shape[0:2])
    for i in range(batch.shape[0]):
        mask[i, 0:seq_lengths[i]] = torch.ones(seq_lengths[i])
    return mask
    
def reverse_seqs(batch, seq_lengths):
    """Utility function for reversing rnn and mini batch vectors
    """
    reversed_seq = torch.zeros_like(batch)
    for i in range(batch.size(0)):
        T = seq_lengths[i]
        time_slice = torch.arange(T-1, -1, -1)
        reversed_seq[i, 0:T, :] = torch.index_select(batch[i, :, :], 0, time_slice)
    return reversed_seq

def prep(x, x_lengths):
    """Given batch x and lengths x_lengths, sorts and creates mask and reversed x"""
    # sort from longest to shortest
    sorted_length_indices = torch.argsort(x_lengths, descending=True)
    x_lengths = x_lengths[sorted_length_indices]
    # cut off the unnecessary padding
    max_length = torch.max(x_lengths)
    x = x[sorted_length_indices, 0:max_length, :]
    # reverse and pack to prepare for input to rnn
    x_reversed = nn.utils.rnn.pack_padded_sequence(
        reverse_seqs(x, x_lengths),
        x_lengths,
        batch_first=True
    )
    x_mask = make_mask(x, x_lengths)
    return x, x_reversed, x_mask, x_lengths

In [4]:
class Emitter(nn.Module):
    """Produces an output vector p(xt | zt), the probability params of a given xt given
        the latent variables zt at node t
    xt is a vector of 1s and 0s, so this essentially generates a vectorized Bernoulli likelihood
    """
    def __init__(self, z_dim, x_dim, hidden_dim):
        super().__init__()
        # initialize layers for network with 2 hidden layers
        self.lin1 = nn.Linear(z_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, hidden_dim)
        self.lin3 = nn.Linear(hidden_dim, x_dim)
        # init nonlinearities
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, zt):
        h1 = self.relu(self.lin1(zt))
        h2 = self.relu(self.lin2(h1))
        probs = self.sigmoid(self.lin3(h2))
        return probs

        
class Transition(nn.Module):
    """Given zt_1, produces loc and scale parameters for the distribution of zt
    Making this a gated transition might improve fit?
    """
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        # init layers for loc and scale, one hidden for each
        self.z_to_hidden_loc = nn.Linear(z_dim, hidden_dim)
        self.hidden_to_loc = nn.Linear(hidden_dim, z_dim)
        self.z_to_hidden_scale = nn.Linear(z_dim, hidden_dim)
        self.hidden_to_scale = nn.Linear(hidden_dim, z_dim)
        # init nonlinearities
        self.relu = nn.ReLU()
        self.softplus = nn.Softplus()
    
    def forward(self, zt_1):
        loc = self.hidden_to_loc(
            self.relu(self.z_to_hidden_loc(zt_1))
        )
        scale = self.softplus(self.hidden_to_scale(
            self.relu(self.z_to_hidden_scale(zt_1))
        ))
        return loc, scale

    
class GuideNet(nn.Module):
    """Combines RNN output and zt_1 to produce parameters for variational distribution
    variational dist. has form q(zt | zt_1, x{t:T})
    The job here is analagous to what Transition does, but for the variational dist.
    """
    def __init__(self, z_dim, rnn_dim):
        super().__init__()
        # init linear transformations, 1 hidden layer for loc and scale
        self.z_to_hidden = nn.Linear(z_dim, rnn_dim)
        self.hidden_to_loc = nn.Linear(rnn_dim, z_dim)
        self.hidden_to_scale = nn.Linear(rnn_dim, z_dim)
        # init nonlinearities
        self.relu = nn.ReLU()
        self.softplus = nn.Softplus()
    
    def forward(self, zt_1, rnn_out):
        # take mean of first layer and rnn_out
        combined = 0.5 * (self.relu(self.z_to_hidden(zt_1)) + rnn_out)
        loc = self.hidden_to_loc(combined)
        scale = self.softplus(self.hidden_to_scale(combined))
        return loc, scale

   
class DeepMarkov(nn.Module):
    
    def __init__(self, x_dim=KEEP_WORDS+1, z_dim=50, transition_dim=100, emitter_dim=50,
                 rnn_dim=200, rnn_dropout=0.0):
        super().__init__()
        # init neural nets used in model and guide
        self.emitter = Emitter(z_dim, x_dim, emitter_dim)
        self.transition = Transition(z_dim, transition_dim)
        self.rnn = nn.RNN(x_dim, hidden_size=rnn_dim, nonlinearity='relu',
                          batch_first=True, dropout=rnn_dropout)
        self.guideNet = GuideNet(z_dim, rnn_dim)
        # define trainable parameters for the z's on the first node
        # z0 is for the model; zq0 is for the guide
        self.z0 = nn.Parameter(torch.zeros(z_dim))
        self.zq0 = nn.Parameter(torch.zeros(z_dim))
        self.svi = None
    
    def model(self, x, x_reversed, x_mask, x_lengths, annealing_factor=1.0):
        T = x.size(1)
        # register self with Pyro
        pyro.module('deep_markov', self)
        # z_prev is z on the previous node; expand to have same length as x
        z_prev = self.z0.expand(x.size(0), self.z0.size(0))
        with pyro.plate('z_batch', len(x)):
            for t in range(1, T+1):
                # sample zt ~ p(zt | zt_1)
                z_loc, z_scale = self.transition(z_prev)  # get params
                # scale log probabilities for KL annealing
                with poutine.scale(None, annealing_factor):
                    # mask to deal with uneven lengths on x vector
                    zt = pyro.sample(f'z{t}', dist.Normal(z_loc, z_scale)
                                     .mask(x_mask[:, (t-1):t]).to_event(1))
                # compute likelihood p(xt | zt)
                emitter_probs = self.emitter(zt)
                pyro.sample(f'obs_x{t}', dist.OneHotCategorical(probs=emitter_probs)
                            .mask(x_mask[:, (t-1):t]).to_event(1),
                            obs=x[:, t-1, :])
                # set new z_prev as this zt
                z_prev = zt
    
    def guide(self, x, x_reversed, x_mask, x_lengths, annealing_factor=1.0):
        T = x.size(1)
        pyro.module('deep_markov', self)
        rnn_out, _ = self.rnn(x_reversed.float())
        rnn_out, _ = nn.utils.rnn.pad_packed_sequence(rnn_out, batch_first=True)
        rnn_out = reverse_seqs(rnn_out, x_lengths)
        z_prev = self.z0.expand(x.size(0), self.z0.size(0))
        with pyro.plate('z_batch', len(x)):
            for t in range(1, T-1):
                # sample from q(zt | zt_1, x{t:T})
                z_loc, z_scale = self.guideNet(z_prev, rnn_out[:, (t-1), :])
                with pyro.poutine.scale(None, annealing_factor):
                    zt = pyro.sample(f'z{t}', dist.Normal(z_loc, z_scale)
                                     .mask(x_mask[:, (t-1):t]).to_event(1))
                z_prev = zt
    
    def fit(self, x_train, x_train_lengths,
            lr=0.001, clip_norm=10.0, batch_size=32, n_epochs=1,
            min_anneal=0.2, anneal_step_per_batch=0.001):
        if self.svi is None:
            optimizer = optim.ClippedAdam({'lr': lr, 'clip_norm': clip_norm})
            self.svi = infer.SVI(self.model, self.guide, optimizer, infer.Trace_ELBO())
        total_size = len(x_train)
        annealing_factor = min_anneal
        print(f'batches per epoch: {total_size // batch_size + 1}')
        losses = []
        for epoch in range(n_epochs):
            epoch_loss = 0
            time0 = time()
            for i in range(0, total_size, batch_size):
                x = x_train[i:(i+batch_size)]
                x_lengths = x_train_lengths[i:(i+batch_size)]
                x, x_reversed, x_mask, x_lengths = prep(x, x_lengths)
                annealing_factor = min(1.0, annealing_factor + anneal_step_per_batch)
                epoch_loss += self.svi.step(x, x_reversed, x_mask, x_lengths, annealing_factor)
            losses.append(epoch_loss / total_size)
            print(f'Epoch {epoch+1}: loss = {epoch_loss / total_size:.5f}, time = {time() - time0:.2f}')
        return losses
    
    def evaluate(self, x_test, x_test_lengths, batch_size=32):
        assert self.svi is not None, 'Must run fit first'
        total_size = len(x_test)
        loss = 0
        for i in range(0, total_size, batch_size):
            x = x_test[i:(i+batch_size)]
            x_lengths = x_test_lengths[i:(i+batch_size)]
            x, x_reversed, x_mask, x_lengths = prep(x, x_lengths)
            loss += self.svi.evaluate_loss(x, x_reversed, x_mask, x_lengths)
        print(f'test loss is {loss / total_size}')
        return loss / total_size

In [5]:
x_train, x_train_lengths = get_group(0)
x_test, x_test_lengths = get_group(1)

pyro.clear_param_store()
deepMarkov = DeepMarkov()
deepMarkov.fit(x_train, x_train_lengths, n_epochs=2)
deepMarkov.evaluate(x_test, x_test_lengths)

batches per epoch: 233
Epoch 1: loss = 1119.55199, time = 90.46


  warn_if_nan(loss, "loss")


Epoch 2: loss = nan, time = 93.69


  warn_if_nan(loss, "loss")


test loss is {loss / total_size}


nan