In [1]:
!pip install dynet
!git clone https://github.com/neubig/nn4nlp-code.git

fatal: destination path 'nn4nlp-code' already exists and is not an empty directory.


In [2]:
from __future__ import print_function
import time

from collections import defaultdict
import random
import math
import sys
import argparse

import dynet as dy
import numpy as np
import pdb


#some of this code borrowed from Qinlan Shen's attention from the MT class last year
#much of the beginning is the same as the text retrieval
# format of files: each line is "word1 word2 ..." aligned line-by-line
train_src_file = "nn4nlp-code/data/parallel/train.ja"
train_trg_file = "nn4nlp-code/data/parallel/train.en"
dev_src_file = "nn4nlp-code/data/parallel/dev.ja"
dev_trg_file = "nn4nlp-code/data/parallel/dev.en"

w2i_src = defaultdict(lambda: len(w2i_src))
w2i_trg = defaultdict(lambda: len(w2i_trg))


# Creates batches where all source sentences are the same length
def create_batches(sorted_dataset, max_batch_size):
    source = [x[0] for x in sorted_dataset]
    src_lengths = [len(x) for x in source]
    batches = []
    prev = src_lengths[0]
    prev_start = 0
    batch_size = 1
    for i in range(1, len(src_lengths)):
        if src_lengths[i] != prev or batch_size == max_batch_size:
            batches.append((prev_start, batch_size))
            prev = src_lengths[i]
            prev_start = i
            batch_size = 1
        else:
            batch_size += 1
    return batches


def read(fname_src, fname_trg):
    """
    Read parallel files where each line lines up
    """
    with open(fname_src, "r") as f_src, open(fname_trg, "r") as f_trg:
        for line_src, line_trg in zip(f_src, f_trg):
            #need to append EOS tags to at least the target sentence
            sent_src = [w2i_src[x] for x in line_src.strip().split() + ['</s>']] 
            sent_trg = [w2i_trg[x] for x in ['<s>'] + line_trg.strip().split() + ['</s>']] 
            yield (sent_src, sent_trg)

# Read the data
train = list(read(train_src_file, train_trg_file))
unk_src = w2i_src["<unk>"]
eos_src = w2i_src['</s>']
w2i_src = defaultdict(lambda: unk_src, w2i_src)
unk_trg = w2i_trg["<unk>"]
eos_trg = w2i_trg['</s>']
sos_trg = w2i_trg['<s>']
w2i_trg = defaultdict(lambda: unk_trg, w2i_trg)
i2w_trg = {v: k for k, v in w2i_trg.items()}

nwords_src = len(w2i_src)
nwords_trg = len(w2i_trg)
dev = list(read(dev_src_file, dev_trg_file))

# DyNet Starts
model = dy.Model()
trainer = dy.AdamTrainer(model)

# Model parameters
EMBED_SIZE = 64
HIDDEN_SIZE = 128
BATCH_SIZE = 16

#Especially in early training, the model can generate basically infinitly without generating an EOS
#have a max sent size that you end at
MAX_SENT_SIZE = 50

# Lookup parameters for word embeddings
LOOKUP_SRC = model.add_lookup_parameters((nwords_src, EMBED_SIZE))
LOOKUP_TRG = model.add_lookup_parameters((nwords_trg, EMBED_SIZE))

# Word-level LSTMs
LSTM_SRC_BUILDER = dy.LSTMBuilder(1, EMBED_SIZE, HIDDEN_SIZE, model)
LSTM_TRG_BUILDER = dy.LSTMBuilder(1, EMBED_SIZE, HIDDEN_SIZE, model)

#the softmax from the hidden size 
W_sm_p = model.add_parameters((nwords_trg, HIDDEN_SIZE))         # Weights of the softmax
b_sm_p = model.add_parameters((nwords_trg))                   # Softmax bias



def calc_loss(sents):
    dy.renew_cg()

    # Transduce all batch elements with an LSTM
    src_sents = [x[0] for x in sents]
    tgt_sents = [x[1] for x in sents]
    src_cws = []

    src_len = [len(sent) for sent in src_sents]        
    max_src_len = np.max(src_len)
    num_words = 0

    for i in range(max_src_len):
        src_cws.append([sent[i] for sent in src_sents])


    #initialize the LSTM
    init_state_src = LSTM_SRC_BUILDER.initial_state()

    #get the output of the first LSTM
    src_output = init_state_src.add_inputs([dy.lookup_batch(LOOKUP_SRC, cws) for cws in src_cws])[-1].output()
    #now decode
    all_losses = []

    # Decoder
    #need to mask padding at end of sentence
    tgt_cws = []
    tgt_len = [len(sent) for sent in sents]
    max_tgt_len = np.max(tgt_len)
    masks = []

    for i in range(max_tgt_len):
        tgt_cws.append([sent[i] if len(sent) > i else eos_trg for sent in tgt_sents])
        mask = [(1 if len(sent) > i else 0) for sent in tgt_sents]
        masks.append(mask)
        num_words += sum(mask)



    current_state = LSTM_TRG_BUILDER.initial_state().set_s([src_output, dy.tanh(src_output)])
    prev_words = tgt_cws[0]
    W_sm = dy.parameter(W_sm_p)
    b_sm = dy.parameter(b_sm_p)

    for next_words, mask in zip(tgt_cws[1:], masks):
        #feed the current state into the 
        current_state = current_state.add_input(dy.lookup_batch(LOOKUP_TRG, prev_words))
        output_embedding = current_state.output()

        s = dy.affine_transform([b_sm, W_sm, output_embedding])
        loss = (dy.pickneglogsoftmax_batch(s, next_words))
        mask_expr = dy.inputVector(mask)
        mask_expr = dy.reshape(mask_expr, (1,),len(sents))
        mask_loss = loss * mask_expr
        all_losses.append(mask_loss)
        prev_words = next_words
    return dy.sum_batches(dy.esum(all_losses)), num_words

def generate(sent):
    dy.renew_cg()

    # Transduce all batch elements with an LSTM
    sent_reps = [LSTM_SRC.transduce([LOOKUP_SRC[x] for x in src])[-1] for src, trg in sents]

    dy.renew_cg()

    # Transduce all batch elements with an LSTM
    src = sent[0]
    trg = sent[1]


    #initialize the LSTM
    init_state_src = LSTM_SRC_BUILDER.initial_state()

    #get the output of the first LSTM
    src_output = init_state_src.add_inputs([LOOKUP_SRC[x] for x in src])[-1].output()

    #generate until a eos tag or max is reached
    current_state = LSTM_TRG_BUILDER.initial_state().set_s([src_output, dy.tanh(src_output)])

    prev_word = sos_trg
    trg_sent = []
    W_sm = dy.parameter(W_sm_p)
    b_sm = dy.parameter(b_sm_p)

    for i in range(MAX_SENT_SIZE):
        #feed the previous word into the lstm, calculate the most likely word, add it to the sentence
        current_state = current_state.add_input(LOOKUP_TRG[prev_word])
        output_embedding = hidden_state.output()
        s = dy.affine_transform([b_sm, W_sm, output_embedding])
        probs = -dy.log_softmax(s).value()
        next_word = np.argmax(probs)

        if next_word == eos_trg:
            break
        prev_word = next_word
        trg_sent.append(i2w_trg[next_word])
    return trg_sent

for ITER in range(100):
  # Perform training
  train.sort(key=lambda t: len(t[0]), reverse=True)
  dev.sort(key=lambda t: len(t[0]), reverse=True)
  train_order = create_batches(train, BATCH_SIZE) 
  dev_order = create_batches(dev, BATCH_SIZE)
  train_words, train_loss = 0, 0.0
  start = time.time()
  for sent_id, (start, length) in enumerate(train_order):
    train_batch = train[start:start+length]
    my_loss, num_words = calc_loss(train_batch)
    train_loss += my_loss.value()
    train_words += num_words
    my_loss.backward()
    trainer.update()
    if (sent_id+1) % 5000 == 0:
      print("--finished %r sentences" % (sent_id+1))
  print("iter %r: train loss/word=%.4f, ppl=%.4f, time=%.2fs" % (ITER, train_loss/train_words, math.exp(train_loss/train_words), time.time()-start))
  # Evaluate on dev set
  dev_words, dev_loss = 0, 0.0
  start = time.time()
  for sent_id, (start, length) in enumerate(dev_order):
    dev_batch = dev[start:start+length]
    my_loss, num_words = calc_loss(dev_batch)
    dev_loss += my_loss.value()
    dev_words += num_words
    trainer.update()
  print("iter %r: dev loss/word=%.4f, ppl=%.4f, time=%.2fs" % (ITER, dev_loss/dev_words, math.exp(dev_loss/dev_words), time.time()-start))


iter 0: train loss/word=2.2086, ppl=9.1031, time=1529670659.63s
iter 0: dev loss/word=2.2088, ppl=9.1048, time=1529680153.47s
iter 1: train loss/word=1.8026, ppl=6.0652, time=1529670689.02s
iter 1: dev loss/word=2.4381, ppl=11.4517, time=1529680182.86s
iter 2: train loss/word=1.6373, ppl=5.1410, time=1529670718.44s
iter 2: dev loss/word=2.2970, ppl=9.9448, time=1529680212.28s
iter 3: train loss/word=1.4634, ppl=4.3208, time=1529670747.85s
iter 3: dev loss/word=2.1379, ppl=8.4815, time=1529680241.70s
iter 4: train loss/word=1.3368, ppl=3.8067, time=1529670777.31s
iter 4: dev loss/word=1.9778, ppl=7.2265, time=1529680271.16s
iter 5: train loss/word=1.2323, ppl=3.4292, time=1529670806.64s
iter 5: dev loss/word=1.9710, ppl=7.1780, time=1529680300.47s
iter 6: train loss/word=1.1411, ppl=3.1302, time=1529670835.87s
iter 6: dev loss/word=2.0018, ppl=7.4024, time=1529680329.72s
iter 7: train loss/word=1.0633, ppl=2.8959, time=1529670865.17s
iter 7: dev loss/word=2.0171, ppl=7.5162, time=152968

iter 27: dev loss/word=3.0381, ppl=20.8658, time=1529680927.84s
iter 28: train loss/word=0.1206, ppl=1.1281, time=1529671462.46s
iter 28: dev loss/word=3.1096, ppl=22.4112, time=1529680956.22s
iter 29: train loss/word=0.1045, ppl=1.1102, time=1529671490.64s
iter 29: dev loss/word=3.1077, ppl=22.3691, time=1529680984.39s
iter 30: train loss/word=0.0921, ppl=1.0965, time=1529671518.88s
iter 30: dev loss/word=3.1497, ppl=23.3287, time=1529681012.64s
iter 31: train loss/word=0.0769, ppl=1.0800, time=1529671547.13s
iter 31: dev loss/word=3.2458, ppl=25.6833, time=1529681040.88s
iter 32: train loss/word=0.0748, ppl=1.0777, time=1529671575.43s
iter 32: dev loss/word=3.2386, ppl=25.4974, time=1529681069.18s
iter 33: train loss/word=0.0654, ppl=1.0676, time=1529671603.71s
iter 33: dev loss/word=3.2893, ppl=26.8246, time=1529681097.47s
iter 34: train loss/word=0.0578, ppl=1.0595, time=1529671632.53s
iter 34: dev loss/word=3.3192, ppl=27.6390, time=1529681126.33s
iter 35: train loss/word=0.0520, 

iter 55: train loss/word=0.0087, ppl=1.0087, time=1529672226.35s
iter 55: dev loss/word=4.0445, ppl=57.0842, time=1529681720.12s
iter 56: train loss/word=0.0083, ppl=1.0084, time=1529672254.59s
iter 56: dev loss/word=3.9705, ppl=53.0090, time=1529681748.35s
iter 57: train loss/word=0.0077, ppl=1.0078, time=1529672282.87s
iter 57: dev loss/word=3.9548, ppl=52.1850, time=1529681776.63s
iter 58: train loss/word=0.0099, ppl=1.0100, time=1529672311.22s
iter 58: dev loss/word=4.0250, ppl=55.9782, time=1529681804.96s
iter 59: train loss/word=0.0088, ppl=1.0089, time=1529672339.45s
iter 59: dev loss/word=4.0005, ppl=54.6264, time=1529681833.20s
iter 60: train loss/word=0.0075, ppl=1.0075, time=1529672367.65s
iter 60: dev loss/word=3.8663, ppl=47.7655, time=1529681861.41s
iter 61: train loss/word=0.0085, ppl=1.0085, time=1529672395.90s
iter 61: dev loss/word=3.9915, ppl=54.1383, time=1529681889.64s
iter 62: train loss/word=0.0065, ppl=1.0065, time=1529672424.03s
iter 62: dev loss/word=4.0016, p

iter 82: dev loss/word=4.4591, ppl=86.4120, time=1529682484.49s
iter 83: train loss/word=0.0043, ppl=1.0043, time=1529673019.13s
iter 83: dev loss/word=4.4117, ppl=82.4130, time=1529682512.89s
iter 84: train loss/word=0.0053, ppl=1.0053, time=1529673047.47s
iter 84: dev loss/word=4.4746, ppl=87.7586, time=1529682541.23s
iter 85: train loss/word=0.0039, ppl=1.0039, time=1529673075.97s
iter 85: dev loss/word=4.4763, ppl=87.9097, time=1529682569.73s
iter 86: train loss/word=0.0043, ppl=1.0043, time=1529673104.45s
iter 86: dev loss/word=4.5566, ppl=95.2595, time=1529682598.20s
iter 87: train loss/word=0.0039, ppl=1.0039, time=1529673132.85s
iter 87: dev loss/word=4.6017, ppl=99.6544, time=1529682626.60s
iter 88: train loss/word=0.0083, ppl=1.0083, time=1529673161.22s
iter 88: dev loss/word=4.3905, ppl=80.6821, time=1529682654.98s
iter 89: train loss/word=0.0042, ppl=1.0042, time=1529673189.57s
iter 89: dev loss/word=4.5781, ppl=97.3312, time=1529682683.32s
iter 90: train loss/word=0.0071, 