In [1]:
!pip install dynet
!git clone https://github.com/neubig/nn4nlp-code.git

Collecting dynet
[?25l  Downloading https://files.pythonhosted.org/packages/4f/de/181a8380e9fdb89d9aa5838059336bb535503d5f2053e621438e69081407/dyNET-2.0.3-cp27-cp27mu-manylinux1_x86_64.whl (27.6MB)
[K    100% |████████████████████████████████| 27.6MB 1.1MB/s 
Collecting cython (from dynet)
[?25l  Downloading https://files.pythonhosted.org/packages/f6/23/ef5521e077e9e7ef8e4603e27713ae95fee69e9c19c7cd036b4299c7ced5/Cython-0.28.3-cp27-cp27mu-manylinux1_x86_64.whl (3.3MB)
[K    100% |████████████████████████████████| 3.3MB 10.5MB/s 
[?25hInstalling collected packages: cython, dynet
Successfully installed cython-0.28.3 dynet-2.0.3
Cloning into 'nn4nlp-code'...
remote: Counting objects: 372, done.[K
remote: Total 372 (delta 0), reused 0 (delta 0), pack-reused 372[K
Receiving objects: 100% (372/372), 6.33 MiB | 19.81 MiB/s, done.
Resolving deltas: 100% (131/131), done.


In [0]:
from __future__ import print_function
import time

start = time.time()

from collections import Counter, defaultdict
import random
import math
import sys
import argparse

import dynet as dy
import numpy as np

In [0]:
# Corpus locations. Every line of a corpus file has the form "word1 word2 ...".
# Note that `test_file` points at the PTB *validation* split (valid.txt).
train_file = "nn4nlp-code/data/ptb/train.txt"
test_file = "nn4nlp-code/data/ptb/valid.txt"


def _next_word_id():
    """Return the id for a word seen for the first time: the current vocabulary size."""
    return len(w2i)


# Word -> integer id. Looking up an unseen word inserts it with the next free id.
w2i = defaultdict(_next_word_id)


def read(fname):
    """
    Lazily convert a corpus file into sentences of word ids.

    Each line of `fname` is expected to look like "word1 word2 ...". For every
    line (including blank ones) one list is yielded: the ids of its
    whitespace-separated words, in order, followed by the id of the "<s>"
    marker.

    Ids come from the module-level `w2i` mapping, so reading a file has the
    side effect of adding any previously unseen word to `w2i`.
    """
    with open(fname, "r") as corpus:
        for raw_line in corpus:
            tokens = raw_line.strip().split()
            # The sentence words are looked up before "<s>", so on first sight
            # they receive their ids ahead of the marker.
            tokens.append("<s>")
            yield [w2i[token] for token in tokens]


# Index the training corpus first: the vocabulary size is taken from it and is
# later used to size the embedding and softmax parameters.
train = list(read(train_file))
nwords = len(w2i)
# The validation file goes through the same auto-growing w2i, so any word not
# seen in training would silently receive an id >= nwords.
test = list(read(test_file))
S = w2i["<s>"]
# Fail loudly, with the numbers needed to diagnose it, if that happened.
assert nwords == len(w2i), (
    "vocabulary grew from %d to %d entries after the training data was read; "
    "ids >= nwords would be out of range for parameters sized with nwords"
    % (nwords, len(w2i))
)

In [0]:
# DyNet Starts
# Create the parameter container and the trainer that updates it. Everything
# created below (embeddings, LSTM, softmax) is registered on `model`.
# This cell needs `nwords` from the data-loading cell, so run that cell first.
model = dy.Model()
trainer = dy.AdamTrainer(model)

# Lookup parameters for word embeddings
# EMBED_SIZE is used both as the embedding width and as the LSTM input size;
# HIDDEN_SIZE is the LSTM output size and the softmax input size.
EMBED_SIZE = 64
HIDDEN_SIZE = 128
# One EMBED_SIZE-wide entry per word id in [0, nwords).
WORDS_LOOKUP = model.add_lookup_parameters((nwords, EMBED_SIZE))

# Word-level LSTM (layers=1, input=64, output=128, model)
RNN = dy.LSTMBuilder(1, EMBED_SIZE, HIDDEN_SIZE, model)

# Softmax weights/biases on top of LSTM outputs
# W_sm is (nwords, HIDDEN_SIZE) and b_sm has nwords entries, so
# W_sm * h + b_sm produces one score per word id.
W_sm = model.add_parameters((nwords, HIDDEN_SIZE))
b_sm = model.add_parameters(nwords)


# Build the language model graph
def calc_lm_loss(sent):
    """
    Build the language-model loss for one sentence on a fresh computation graph.

    The LSTM is primed with the embedding of the "<s>" id (`S`). Then, for each
    word id in `sent`, the current LSTM output is projected to vocabulary
    scores, the negative log-softmax of that word id is recorded, and the
    word's embedding is fed to the LSTM to advance the state.

    Args:
        sent: list of word ids (as produced by `read`, presumably ending with
            the "<s>" id -- confirm against the caller).

    Returns:
        A pair `(loss, n_words)`: the DyNet expression summing the per-word
        losses, and `len(sent)`.
    """
    dy.renew_cg()

    # Softmax parameters as expressions in the new graph.
    softmax_w = dy.parameter(W_sm)
    softmax_b = dy.parameter(b_sm)

    # Start the recurrence from the initial state with "<s>" as first input.
    state = RNN.initial_state().add_input(WORDS_LOOKUP[S])

    step_losses = []
    for target in sent:
        # Score every word from the current state and record the loss of the true next word.
        logits = softmax_w * state.output() + softmax_b
        step_losses.append(dy.pickneglogsoftmax(logits, target))
        # Advance the LSTM with the word that was just predicted.
        state = state.add_input(WORDS_LOOKUP[target])

    return dy.esum(step_losses), len(sent)


# Indices of the training sentences, initially in corpus order; the training
# loop reshuffles this sequence in place at the start of every epoch.
# It has to be a real list: random.shuffle() needs a mutable sequence, and on
# Python 3 a bare range() is immutable and makes the shuffle raise TypeError.
train_order = list(range(len(train)))

In [0]:
print("startup time: %r" % (time.time() - start))

# Loop settings, named so they can be tuned in one place.
NUM_EPOCHS = 100     # full passes over the training sentences
REPORT_EVERY = 500   # sentences between running-loss reports
EVAL_EVERY = 10000   # sentences between evaluations on `test`

# Perform training
start = time.time()
i = all_time = dev_time = all_tagged = this_words = this_loss = 0
# random.shuffle() needs a mutable sequence; on Python 3 a bare range() is
# immutable and would raise TypeError, so materialise the order as a list.
train_order = list(train_order)
for ITER in range(NUM_EPOCHS):
    random.shuffle(train_order)
    for sid in train_order:
        i += 1
        if i % REPORT_EVERY == 0:
            # Progress report: trainer status plus the average per-word loss
            # accumulated since the previous report (written to stderr).
            trainer.status()
            print(this_loss / this_words, file=sys.stderr)
            all_tagged += this_words
            this_loss = this_words = 0
            all_time = time.time() - start
        if i % EVAL_EVERY == 0:
            # Score every held-out sentence. The time spent here is tracked
            # separately so it can be subtracted from the training time.
            dev_start = time.time()
            dev_loss = dev_words = 0
            for sent in test:
                loss_exp, mb_words = calc_lm_loss(sent)
                dev_loss += loss_exp.scalar_value()
                dev_words += mb_words
            dev_time += time.time() - dev_start
            train_time = time.time() - start - dev_time
            dev_nll = dev_loss / dev_words
            print("nll=%.4f, ppl=%.4f, words=%r, time=%.4f, word_per_sec=%.4f" % (
                dev_nll, math.exp(dev_nll), dev_words, train_time, all_tagged / train_time))
        # Train on one sentence: forward pass, accumulate statistics, backward pass, update.
        loss_exp, mb_words = calc_lm_loss(train[sid])
        this_loss += loss_exp.scalar_value()
        this_words += mb_words
        loss_exp.backward()
        trainer.update()
    print("epoch %r finished" % ITER)
    # NOTE(review): the former trainer.update_epoch(1.0) call was dropped here.
    # DyNet 2.x is understood to treat it as a deprecated no-op (learning-rate
    # control moved to trainer.learning_rate) -- confirm against the DyNet docs
    # for the installed version before relying on any per-epoch decay.

startup time: 1.5743510723114014


7.13827299469
6.68372788932
6.56000968153
6.48601439531
6.41959502674
6.37209053545
6.2874563398
6.24974389991
6.20393947699
6.22451080823
6.03801972917
6.12050731624
6.06029874266
6.03172786914
5.98755370178
5.97345763071
5.93895772363
5.8454166822
5.86689922595
5.81222066633


nll=5.8437, ppl=345.0623, words=73760, time=541.3720, word_per_sec=409.1290


5.79615720606
5.80938518573
5.72593288831
5.77341154408
5.74417290724
5.67134200219
5.60873633325
5.71535864284
5.76652435086
5.72160954594
5.63925532553
5.66350005485
5.63274236004
5.69882071044
5.57804742479
5.58517945686
5.56072322416
5.6069517123
5.49212287856
5.66331230684

nll=5.5961, ppl=269.3728, words=73760, time=1102.3300, word_per_sec=400.6649



5.60326360614
5.56828683159
5.52155342255
5.54291894474
5.5438365297
5.55297619938
5.45302811603
5.43357549258
5.50415297873
5.48438971178
5.50856491676
5.48664953523
5.44988535703
5.43145093157
5.54174557058
5.49743947081
5.46964364476
5.46231223525
5.5778975564
5.4585915418


nll=5.4760, ppl=238.8821, words=73760, time=1617.6296, word_per_sec=409.2037


5.40576742058
5.49219760756
5.42438429566
5.41305421014
5.34943064159
5.40813076417
5.42343847327
5.41369441914
5.40307380458
5.48832826515
5.43962900312
5.43060751134
5.34307060646
5.39962206578
5.31177674769
5.34375026923
5.34094594876
5.34451068655
5.31610741409
5.38514992829


nll=5.3938, ppl=220.0475, words=73760, time=2138.7431, word_per_sec=412.9056


5.41245563444
5.32780871376
5.41199616069
5.32797385257


epoch 0 finished


5.29203374717
5.24316151261
5.11191251485
5.17207399117
5.22250087449
5.18526077848
5.16826281371
5.22363877665
5.16923231767
5.16053822114
5.11309971808
5.28149341506
5.28101823841
5.22604056078
5.19269865205
5.22525116564


nll=5.3521, ppl=211.0434, words=73760, time=2652.7506, word_per_sec=416.3788


5.15558300949
5.10378297743
5.20394305653
5.25490003954
5.21285855904
5.22243514138
5.29237304725
5.28920646352
5.1787725695
5.21304962849
5.08591081371
5.10148259892
5.22727288526
5.12529780215
5.18294285856
5.15041936569
5.14122186468
5.16989877137
5.13088498173
5.24200942079


nll=5.3176, ppl=203.8883, words=73760, time=3172.0679, word_per_sec=418.2051


5.16379000838
5.2519836188
5.15643274643
5.19646079856
5.08433802657
5.18758335985
5.18093749059
5.03179717961
5.18113488226
5.11118929767
5.21866854659
5.19067934323
5.22487256346
5.1281171775
5.176186534
5.08659060371
5.18320135033
5.10718225908
5.1643815706
5.17418917686


nll=5.2919, ppl=198.7180, words=73760, time=3684.0643, word_per_sec=419.8282


5.08725284182
5.18750657707
5.17996684319
5.18571003543
5.2189700388
5.06652752732
5.21722251273
5.07079000215
5.11676094104
5.1967884564
5.15276457278
5.16796241171
5.21695697846
5.10565969401
5.15014041062
5.12258350957
5.1245775891
5.15459356652
5.09261295989
5.1604620696


nll=5.2551, ppl=191.5343, words=73760, time=4254.0007, word_per_sec=415.4240


5.18817154271
5.16895981322
5.07355719775
5.18825254981
5.12671319326
5.1749644008
5.0713480581
5.1159221129


epoch 1 finished


4.97057674728
4.9944453667
4.95889981725
4.89614205198
4.89073198472
4.97882679847
4.90557561336
5.02620715516
4.99074187437
4.94921378772
4.96216497786
4.92798497437


nll=5.2529, ppl=191.1192, words=73760, time=4803.4075, word_per_sec=414.1481


5.0455947966
4.86722804174
4.99976006317
5.03526972435
4.97199427063
4.95130373489
4.9628480353
4.9574815688
5.00569229414
5.00140695011
4.95769079827
4.96788953568
4.97780652424
5.00976601968
4.97022266184
5.03123904694
5.01865884999
5.01502736428
4.96678453556
4.96686009381


nll=5.2250, ppl=185.8543, words=73760, time=5635.0056, word_per_sec=392.4083


5.0134905173
4.96511253147
4.96838872069
5.03218391085
5.07595872302
4.95248603217
4.9995731367
5.07845263449
4.96658564516
4.9191736337
4.97647522795
4.9815864571
5.04884630594
4.94204309024
4.98035747864
5.02645306851
5.01410484473
4.98776444579
4.99302262815
4.97139605216


nll=5.2085, ppl=182.8206, words=73760, time=6476.2739, word_per_sec=375.3155


5.1019404817
5.08724183105
5.03247539222
5.00815061508
4.96571629127
5.00955189134
4.99214627459
5.03696748711
5.04215865412
5.01705293984
5.01737535295
4.99842077604
4.96613429321
5.02256667209
5.03518358678
4.99057816365
5.0144504009
4.98816747845
4.99500405967
4.9627097043


nll=5.1935, ppl=180.0911, words=73760, time=7310.1414, word_per_sec=362.8168


5.04292155946
5.00946773848
4.97607301428
5.02785509796
5.01683640459
4.91564281397
5.01801782137
5.06497876356
4.9760113271
5.02177533317
5.03372880499
5.00021407183


epoch 2 finished


4.82750659572
4.70206766446
4.84491064337
4.81180234445
4.78885094028
4.90836786491
4.8406313113
4.78400079052


nll=5.1963, ppl=180.6059, words=73760, time=8174.8702, word_per_sec=351.4515


4.80809048233
4.83197195438
4.8390549243
4.84406486121
4.75441796837
4.80943137794
4.8601280246
4.84029574248
4.85397066537
4.87411547738
4.93811393247
4.94845296827
4.82563077108
4.77974389753
4.89647271342
4.90526842008
4.84498317483
4.872687275
4.8955785678
4.82944000508


nll=5.1907, ppl=179.5978, words=73760, time=8835.2329, word_per_sec=350.0582


4.84307931232
5.01044274075
4.90021608675
4.78352179228
4.94893919472
4.85791917638
4.85416843902
4.96084296515
4.81437607311
4.92999126143
4.91094276712
4.8435356027
4.84419598282
4.87374449634
4.89756374736
4.90545188847
4.96775848607
4.91326324096
4.94277935707
4.83523429665


nll=5.1789, ppl=177.4872, words=73760, time=9580.1337, word_per_sec=346.0787


4.93106753939
4.89697291807
4.82680927088
4.89521662653
4.95381030598
4.88557286453
4.8652371724
4.90917584892
4.92080321064
4.89327156684
4.89620714557
5.00639929103
4.93253617473
4.94882165399
4.94143612206
4.86901020583
4.84016516411
4.98391814592
4.91233789675
4.87917740008


nll=5.1718, ppl=176.2258, words=73760, time=10197.6581, word_per_sec=346.6823


4.96372507539
4.95950446973
4.88961972938
4.90763167254
4.93527888489
4.92850976542
4.96134477416
4.9583327596
4.93580117182
4.86260777347
4.91244940705
4.8860988276
4.90103601586
4.86991178253
5.00066448189
4.84149971263


epoch 3 finished


4.80687811038
4.81155921894
4.75644301998
4.74821407118


nll=5.1653, ppl=175.0879, words=73760, time=10737.0382, word_per_sec=349.8818


4.71188926205
4.73072816182
4.77958321832
4.72241238893
4.81545596829
4.78612998041
4.75512575462
4.70760169908
4.75316895817
4.76525193084
4.84469466928
4.70227915235
4.70296871699
4.75268968841
4.78268377919
4.84372124714
4.77149611795
4.76996271103
4.76947084732
4.78165042415


nll=5.1784, ppl=177.4007, words=73760, time=11298.5882, word_per_sec=352.0965


4.73218762579
4.80014401729
4.747198863
4.79126945858
4.86790455179
4.76561869318
4.83957683001
4.81246307843
4.81886903976
4.8017248505
4.7647577226
4.79305894896
4.87221781529
4.70734495608
4.83874116387
4.73478614637
4.74540598692
4.76766161521
4.76566182696
