In [1]:
import torch
import numpy as np
import torch.nn as nn
import pickle

from utils.data import preprocess, word2idx, Dictionary
from nltk import sent_tokenize

[nltk_data] Downloading package punkt to /home/shaderein/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
print(torch.__version__)

1.2.0


# Load pretrained models
`/LSTM_40m` contains a batch of pretrained models. The name convention is:

`LSTM_[Hidden Units]_[Training Tokens]_[Training Partition]_[Random Seed]-d[Dropout Rate].pt`

The following analysis was done on models with 3 different hidden sizes [100/400/1600]

In [4]:
model_100_file = './data/LSTM_40m/LSTM_100_40m_a_0-d0.2.pt'
model_400_file = './data/LSTM_40m/LSTM_400_40m_a_10-d0.2.pt'
model_1600_file = './data/LSTM_40m/LSTM_1600_40m_a_20-d0.2.pt'

model_100 = torch.load(model_100_file, map_location=torch.device('cpu'))
model_100.eval()

model_400 = torch.load(model_400_file, map_location=torch.device('cpu'))
model_400.eval()

model_1600 = torch.load(model_1600_file, map_location=torch.device('cpu'))
model_1600.eval()


print(model_100)

RNNModel(
  (drop): Dropout(p=0.2, inplace=False)
  (encoder): Embedding(28439, 100)
  (rnn): LSTM(100, 100, num_layers=2, dropout=0.2)
  (decoder): Linear(in_features=100, out_features=28439, bias=True)
)


# Load vocab and text files

Load and preprocess the story file. (Lowercase and append "EOS")

Mark highlighted sentences

In [5]:
story_file = "./data/text/Full Story_So much water so close to home_targetedQ_Highlighted.docx"
vocab_file = "./data/vocab.txt"

vocab = Dictionary(vocab_file)
vocab_size = len(vocab.word2idx)
story_text, is_highlight = preprocess(story_file)
processed_story, story_sents = word2idx(story_text, vocab)

sample_sent_idx = 2
print(story_text[sample_sent_idx])
print(processed_story[sample_sent_idx])
print(is_highlight[sample_sent_idx])
print(story_sents[sample_sent_idx])

He chews, arms on the table, and stares at something across the room.
['<eos>', 'he', '<unk>', ',', 'arms', 'on', 'the', 'table', ',', 'and', '<unk>', 'at', 'something', 'across', 'the', 'room', '.']
False
tensor([[28438],
        [   18],
        [28437],
        [    1],
        [ 1124],
        [   13],
        [    0],
        [ 1571],
        [    1],
        [    5],
        [28437],
        [   22],
        [  574],
        [  475],
        [    0],
        [  804],
        [    2]])


To test the validity of `sent_perplexity`, run each model again on an excerpt from [test.txt](https://github.com/vansky/neural-complexity/blob/master/data/wikitext-2/test.txt) used to evaluate the language model. (From wikitext-2 dataset)

In [6]:
test_file = "./data/test.txt"
with open(test_file, "r") as text_file:
    test_text = text_file.read().replace('\n', '')
test_text = sent_tokenize(test_text)
processed_test, test_sents = word2idx(test_text, vocab)

print(test_text[1])
print(processed_test[1])

It caused enormous disruption to Chinese society : the census of 754 recorded 52 @.
['<eos>', 'it', 'caused', 'enormous', 'disruption', 'to', '<unk>', 'society', ':', 'the', 'census', 'of', '<unk>', 'recorded', '<unk>', '@', '.']


# Calculate PPL per sentence

Calculate sentence PPL for each highlight and averaged over them.

For each sentence, PPL is calculated as `exp(cross_entropy_loss[prediction, target])`, where `prediction` includes output from reading w[-1] (the last output from the previous sentence) all the way to w[n-2] and `target` includes w[0],...,w[n-1] (the last word of the current sentence, which is always 'EOS' now after preprocessing.)

In [7]:
from utils.analysis import sent_perplexity
models = [model_100, model_400, model_1600]

## Story text

In [9]:
with torch.no_grad():
    for model in models:
        ppl_sent_highlight = []
        hidden_size = model.nhid
        for i, sent in enumerate(story_sents):
            if i==0:
                hidden = model.init_hidden(bsz=1)
                out, hidden = model(sent, hidden)
                continue
            ppl, out, hidden = sent_perplexity(sent, model, vocab, hidden)
            if is_highlight[i]:
                ppl_sent_highlight.append(ppl)

        print(f"Model_{hidden_size} ppl {np.mean(ppl_sent_highlight)}")

Model_100 ppl 248.5772705078125
Model_400 ppl 142.5596160888672
Model_1600 ppl 118.00228881835938


### Without Context

In [10]:
with torch.no_grad():
    for model in models:
        ppl_sent_highlight = []
        hidden_size = model.nhid
        # initialization
        hidden_init = model.init_hidden(bsz=1)
        for i, sent in enumerate(story_sents):
            ppl, out, hidden = sent_perplexity(sent, model, vocab, hidden_init)
            if is_highlight[i]:
                ppl_sent_highlight.append(ppl)

        print(f"Model_{hidden_size} ppl {np.mean(ppl_sent_highlight)}")

Model_100 ppl 402.2755126953125
Model_400 ppl 254.4184112548828
Model_1600 ppl 327.157470703125


## Test text from wikitext-2

In [11]:
with torch.no_grad():
    for model in models:
        ppl_sent = []
        hidden_size = model.nhid
        for i, sent in enumerate(test_sents):
            if i==0:
                hidden = model.init_hidden(bsz=1)
                out, hidden = model(sent, hidden)
                continue
            ppl, out, hidden = sent_perplexity(sent, model, vocab, hidden)
            if is_highlight[i]:
                ppl_sent.append(ppl)

        print(f"Model_{hidden_size} ppl {np.mean(ppl_sent)}")

Model_100 ppl 76.89266967773438
Model_400 ppl 49.88071060180664
Model_1600 ppl 54.01231384277344


### Without context

In [12]:
with torch.no_grad():
    for model in models:
        ppl_sent_highlight = []
        hidden_size = model.nhid
        # initialization
        hidden_init = model.init_hidden(bsz=1)
        for i, sent in enumerate(test_sents):
            ppl, out, hidden = sent_perplexity(sent, model, vocab, hidden_init)
            if is_highlight[i]:
                ppl_sent_highlight.append(ppl)

        print(f"Model_{hidden_size} ppl {np.mean(ppl_sent_highlight)}")

Model_100 ppl 88.2619400024414
Model_400 ppl 55.22896194458008
Model_1600 ppl 57.528934478759766
