In [1]:
!pip install torch torchtext
!git clone https://github.com/neubig/nn4nlp-code.git

fatal: destination path 'nn4nlp-code' already exists and is not an empty directory.


In [2]:
from __future__ import print_function
import time

from collections import defaultdict
import random
import math
import sys
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pdb

In [3]:
# Some of this code is borrowed from Qinlan Shen's attention code from the MT class last year.
# Much of the beginning is the same as in the text-retrieval example.
# Format of the files: each line is "word1 word2 ..."; source and target files are aligned line-by-line.
train_src_file = "nn4nlp-code/data/parallel/train.ja"
train_trg_file = "nn4nlp-code/data/parallel/train.en"
dev_src_file = "nn4nlp-code/data/parallel/dev.ja"
dev_trg_file = "nn4nlp-code/data/parallel/dev.en"

# Word -> integer id maps. Looking up a word that is not in the map yet inserts
# it with the next free id (the default factory returns the current dict size).
w2i_src = defaultdict(lambda: len(w2i_src))
w2i_trg = defaultdict(lambda: len(w2i_trg))

def read(fname_src, fname_trg):
    """
    Lazily yield (source_ids, target_ids) for each pair of aligned lines.

    Lines are stripped and split on whitespace. The source gets a trailing
    '</s>'; the target is wrapped in '<s>' ... '</s>'. Tokens are converted to
    ids through the module-level w2i_src / w2i_trg maps. Iteration stops at
    the end of the shorter file.
    """
    with open(fname_src, "r") as src_lines:
        with open(fname_trg, "r") as trg_lines:
            for src_text, trg_text in zip(src_lines, trg_lines):
                # the source only needs the end-of-sentence marker
                src_tokens = src_text.strip().split()
                src_tokens.append('</s>')
                # the target needs both start and end markers
                trg_tokens = ['<s>'] + trg_text.strip().split() + ['</s>']

                src_ids = [w2i_src[token] for token in src_tokens]
                trg_ids = [w2i_trg[token] for token in trg_tokens]
                yield (src_ids, trg_ids)

# Read the data
# Reading the training set populates both vocabularies (every new word gets a fresh id).
train = list(read(train_src_file, train_trg_file))
# The maps are still the growing kind here, so "<unk>" is assigned an id by
# this lookup if it is not already present.
unk_src = w2i_src["<unk>"]
eos_src = w2i_src['</s>']
# Re-wrap the source vocabulary: from now on unseen words resolve to the <unk> id
# instead of receiving a new id.
w2i_src = defaultdict(lambda: unk_src, w2i_src)
unk_trg = w2i_trg["<unk>"]
eos_trg = w2i_trg['</s>']
sos_trg = w2i_trg['<s>']
# Same freezing step for the target vocabulary.
w2i_trg = defaultdict(lambda: unk_trg, w2i_trg)
# Reverse (id -> word) map for the target side.
i2w_trg = {v: k for k, v in w2i_trg.items()}

# Order matters: the sizes are taken BEFORE the dev set is read. A defaultdict
# stores every missing key it is asked for, so reading dev afterwards adds
# entries (all mapped to the <unk> id) and would inflate len() if measured later.
nwords_src = len(w2i_src)
nwords_trg = len(w2i_trg)
dev = list(read(dev_src_file, dev_trg_file))

In [4]:
# Model parameters
# Embedding width passed to both EncoderRNN and DecoderRNN.
EMBED_SIZE = 64
# GRU hidden-state width passed to both EncoderRNN and DecoderRNN.
HIDDEN_SIZE = 128
# Number of sentence pairs per batch produced by the DataLoaders.
BATCH_SIZE = 16

In [5]:
from torch.utils.data import Dataset, DataLoader

In [13]:
class ParallelCorpus(Dataset):
  """Dataset view over a sequence of (source_ids, target_ids) pairs."""

  def __init__(self, data):
    # The sequence is stored by reference; nothing is copied or converted here.
    self.data = data

  def __len__(self):
    """Number of sentence pairs in the corpus."""
    return len(self.data)

  def __getitem__(self, ix):
    """Return the ix-th pair as (source LongTensor, target LongTensor)."""
    pair = self.data[ix]
    return tuple(torch.LongTensor(pair[side]) for side in (0, 1))
  
def my_collate_fn(batch):
  """
  Collate (source, target) 1-D tensor pairs into padded batch tensors.

  Each side is right-padded with zeros up to the longest sequence of that side
  and stacked along a new leading batch dimension.

  Returns (src, trg, src_len, trg_len) where the last two are LongTensors
  holding the original, unpadded lengths.
  """
  src_seqs, trg_seqs = zip(*batch)

  def _pad_side(seqs):
    # Right-pad every sequence of one side to the longest one, then stack.
    lengths = [len(seq) for seq in seqs]
    longest = max(lengths)
    rows = []
    for seq, n_tokens in zip(seqs, lengths):
      rows.append(F.pad(seq, (0, longest - n_tokens)))
    return torch.stack(rows), torch.LongTensor(lengths)

  src, src_len = _pad_side(src_seqs)
  trg, trg_len = _pad_side(trg_seqs)
  return src, trg, src_len, trg_len

# my_collate_fn([train_corpus[i] for i in range(4)])

In [14]:
# Wrap the id-encoded sentence pairs as datasets.
train_corpus = ParallelCorpus(train)
dev_corpus = ParallelCorpus(dev)

# Training batches are reshuffled on every pass over the data.
train_loader = DataLoader(
  train_corpus,
  batch_size=BATCH_SIZE,
  shuffle=True,
  num_workers=1,
  collate_fn=my_collate_fn,
)

# Dev batches keep the corpus order.
dev_loader = DataLoader(
  dev_corpus,
  batch_size=BATCH_SIZE,
  shuffle=False,
  num_workers=1,
  collate_fn=my_collate_fn,
)

In [15]:
# Especially in early training, the model can keep generating almost
# indefinitely without ever producing an EOS token, so generation is cut off
# after this many steps.
MAX_SENT_SIZE = 50

In [16]:
class EncoderRNN(nn.Module):
  """
  GRU encoder that summarises each source sentence as a single vector.

  Args:
    input_size: number of rows of the embedding table (source vocabulary size).
    embed_size: embedding width.
    hidden_size: GRU hidden-state width (also the width of the returned vectors).
  """

  def __init__(self, input_size, embed_size, hidden_size):
    super(EncoderRNN, self).__init__()
    self.hidden_size = hidden_size

    self.embedding = nn.Embedding(input_size, embed_size)
    self.gru = nn.GRU(embed_size, hidden_size, batch_first=True)

  def forward(self, x, x_len):
    """
    Encode a batch of right-padded id sequences.

    Args:
      x: (batch, max_len) tensor of token ids.
      x_len: per-sentence unpadded lengths (tensor or sequence of ints); it may
        live on a different device than `x`.

    Returns:
      (batch, hidden_size) tensor holding, for every sentence, the GRU output
      at its last real (non-padding) position.
    """
    batch_size = x.shape[0]
    h0 = self.init_hidden(batch_size)
    encoded = self.embedding(x)
    output, _ = self.gru(encoded, h0)
    # Pick output[i, x_len[i] - 1] for every i with one indexing operation
    # instead of a Python loop over the batch.
    rows = torch.arange(batch_size, device=output.device)
    last_step = torch.as_tensor(x_len, device=output.device) - 1
    return output[rows, last_step]

  def init_hidden(self, bs):
    """Return an all-zero initial hidden state of shape (1, bs, hidden_size)."""
    # Take device and dtype from the module's own parameters rather than from a
    # notebook-level `device` global, so the class does not depend on the order
    # in which cells were executed.
    weight = self.embedding.weight
    return torch.zeros(1, bs, self.hidden_size, device=weight.device, dtype=weight.dtype)
  
class DecoderRNN(nn.Module):
  """
  GRU decoder: embeds the input tokens, runs them through a GRU starting from
  the given hidden state, and projects every GRU output to vocabulary scores.
  """

  def __init__(self, output_size, embed_size, hidden_size):
    super(DecoderRNN, self).__init__()
    self.hidden_size = hidden_size

    # token id -> embedding -> recurrent state -> score per vocabulary entry
    self.embedding = nn.Embedding(output_size, embed_size)
    self.gru = nn.GRU(embed_size, hidden_size, batch_first=True)
    self.out = nn.Linear(hidden_size, output_size)

  def forward(self, x, x_len, hidden):
    """
    Args:
      x: (batch, steps) tensor of input token ids.
      x_len: accepted for interface symmetry with the encoder; not used here.
      hidden: initial GRU hidden state.

    Returns:
      (scores, final_hidden) where scores has one row of output_size values
      per batch element and step.
    """
    embedded = self.embedding(x)
    states, final_hidden = self.gru(embedded, hidden)
    return self.out(states), final_hidden

In [25]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

# Build both networks and move their parameters to the chosen device.
encoder = EncoderRNN(nwords_src, EMBED_SIZE, HIDDEN_SIZE)
encoder = encoder.to(device)
decoder = DecoderRNN(nwords_trg, EMBED_SIZE, HIDDEN_SIZE)
decoder = decoder.to(device)

In [26]:
# Cross-entropy over the vocabulary scores produced by the decoder.
criterion = nn.CrossEntropyLoss()
# A single Adam optimizer updates the encoder and the decoder together.
trainer = torch.optim.Adam(
  [param for module in (encoder, decoder) for param in module.parameters()],
  lr=1e-3,
)

In [27]:
NUM_EPOCHS = 20


def _batch_loss(s, t, sl, tl):
  """
  Teacher-forced cross-entropy for one collated batch.

  Args:
    s, t: padded source / target id tensors as produced by my_collate_fn.
    sl, tl: unpadded source / target lengths.

  Returns:
    Scalar loss tensor, averaged over all non-padding target positions.
  """
  s, t = s.to(device), t.to(device)
  # The decoder reads tokens 0..n-2 and is asked to predict tokens 1..n-1.
  t_in = t[:, :-1]
  t_out = t[:, 1:]
  encoded = encoder(s, sl)
  decoded, _ = decoder(t_in, tl, encoded.unsqueeze(0))
  # A target of length n contributes n-1 prediction steps; everything beyond
  # that is padding and must not enter the loss.
  steps = torch.arange(t_out.shape[1], device=t_out.device).unsqueeze(0)
  valid = steps < (tl.to(t_out.device) - 1).unsqueeze(1)
  return criterion(decoded[valid], t_out[valid])


for epoch_i in range(NUM_EPOCHS):
  encoder.train()
  decoder.train()
  total_loss = 0.
  for s, t, sl, tl in train_loader:
    loss = _batch_loss(s, t, sl, tl)
    total_loss += loss.item()

    trainer.zero_grad()
    loss.backward()
    trainer.step()

  print("epoch {} | train loss {:5.4f}".format(epoch_i, total_loss / len(train_loader)))

  encoder.eval()
  decoder.eval()
  dev_loss = 0.
  # No gradients are needed for validation; skipping the autograd graph saves
  # memory and time.
  with torch.no_grad():
    for s, t, sl, tl in dev_loader:
      dev_loss += _batch_loss(s, t, sl, tl).item()
  print("epoch {} | val loss {:5.4f}".format(epoch_i, dev_loss / len(dev_loader)))

epoch 0 | train loss 5.3064
epoch 0 | val loss 4.8202
epoch 1 | train loss 4.4170
epoch 1 | val loss 4.5080
epoch 2 | train loss 4.0233
epoch 2 | val loss 4.3865
epoch 3 | train loss 3.7268
epoch 3 | val loss 4.3373
epoch 4 | train loss 3.4771
epoch 4 | val loss 4.3153
epoch 5 | train loss 3.2524
epoch 5 | val loss 4.3187
epoch 6 | train loss 3.0453
epoch 6 | val loss 4.3423
epoch 7 | train loss 2.8557
epoch 7 | val loss 4.3697
epoch 8 | train loss 2.6793
epoch 8 | val loss 4.4174
epoch 9 | train loss 2.5186
epoch 9 | val loss 4.4596
epoch 10 | train loss 2.3700
epoch 10 | val loss 4.5049
epoch 11 | train loss 2.2313
epoch 11 | val loss 4.5617
epoch 12 | train loss 2.1050
epoch 12 | val loss 4.6232
epoch 13 | train loss 1.9879
epoch 13 | val loss 4.6945
epoch 14 | train loss 1.8748
epoch 14 | val loss 4.7524
epoch 15 | train loss 1.7748
epoch 15 | val loss 4.8228
epoch 16 | train loss 1.6802
epoch 16 | val loss 4.9006
epoch 17 | train loss 1.5937
epoch 17 | val loss 4.9615
epoch 18 | t

epoch 9 | loss 4.4261


In [0]:
def calc_loss(sents):
    """
    Batched encoder-decoder loss (DyNet implementation).

    NOTE(review): this function relies on `dy`, LSTM_SRC_BUILDER,
    LSTM_TRG_BUILDER, LOOKUP_SRC, LOOKUP_TRG, W_sm_p and b_sm_p, none of which
    are defined in the visible notebook cells -- confirm where they come from
    before running it.

    Args:
        sents: list of (source_ids, target_ids) pairs forming one batch.

    Returns:
        (loss, num_words): the summed negative log-likelihood expression over
        all scored target tokens of the batch, and the number of those tokens
        (every target token except the first one of each sentence, which is
        only ever fed as input).
    """
    dy.renew_cg()

    src_sents = [x[0] for x in sents]
    tgt_sents = [x[1] for x in sents]

    # Encoder: regroup the batch time-major, one list of word ids per step.
    # NOTE(review): sent[i] is read without a length check, so every source
    # sentence in the batch must have the same length -- presumably guaranteed
    # by the batching code; verify against the caller.
    max_src_len = max(len(sent) for sent in src_sents)
    src_cws = [[sent[i] for sent in src_sents] for i in range(max_src_len)]

    #initialize the LSTM
    init_state_src = LSTM_SRC_BUILDER.initial_state()

    #get the output of the first LSTM
    src_output = init_state_src.add_inputs([dy.lookup_batch(LOOKUP_SRC, cws) for cws in src_cws])[-1].output()

    # Decoder
    # Target lengths must come from the target sentences themselves; measuring
    # the (source, target) pairs would always give 2.
    max_tgt_len = max(len(sent) for sent in tgt_sents)
    tgt_cws = []
    masks = []
    for i in range(max_tgt_len):
        # shorter sentences are padded with the end-of-sentence id and masked out
        tgt_cws.append([sent[i] if len(sent) > i else eos_trg for sent in tgt_sents])
        masks.append([(1 if len(sent) > i else 0) for sent in tgt_sents])

    # Position 0 is never predicted, so it is excluded from the word count.
    num_words = sum(sum(mask) for mask in masks[1:])

    current_state = LSTM_TRG_BUILDER.initial_state().set_s([src_output, dy.tanh(src_output)])
    prev_words = tgt_cws[0]
    W_sm = dy.parameter(W_sm_p)
    b_sm = dy.parameter(b_sm_p)

    all_losses = []
    # The word predicted at step i must be masked with the mask of step i,
    # hence both lists are shifted by one.
    for next_words, mask in zip(tgt_cws[1:], masks[1:]):
        #feed the previous words into the decoder and score the next ones
        current_state = current_state.add_input(dy.lookup_batch(LOOKUP_TRG, prev_words))
        output_embedding = current_state.output()

        s = dy.affine_transform([b_sm, W_sm, output_embedding])
        loss = (dy.pickneglogsoftmax_batch(s, next_words))
        mask_expr = dy.inputVector(mask)
        mask_expr = dy.reshape(mask_expr, (1,),len(sents))
        mask_loss = loss * mask_expr
        all_losses.append(mask_loss)
        prev_words = next_words
    return dy.sum_batches(dy.esum(all_losses)), num_words

In [0]:
def generate(sent):
    """
    Greedily decode a translation for one sentence pair (DyNet implementation).

    NOTE(review): this function relies on `dy`, LSTM_SRC_BUILDER,
    LSTM_TRG_BUILDER, LOOKUP_SRC, LOOKUP_TRG, W_sm_p and b_sm_p, none of which
    are defined in the visible notebook cells -- confirm where they come from
    before running it.

    Args:
        sent: (source_ids, target_ids) pair; only the source side is read.

    Returns:
        List of generated target words. Generation stops at the first
        end-of-sentence prediction or after MAX_SENT_SIZE steps.
    """
    dy.renew_cg()

    src = sent[0]

    #initialize the LSTM
    init_state_src = LSTM_SRC_BUILDER.initial_state()

    #get the output of the first LSTM
    src_output = init_state_src.add_inputs([LOOKUP_SRC[x] for x in src])[-1].output()

    #generate until a eos tag or max is reached
    current_state = LSTM_TRG_BUILDER.initial_state().set_s([src_output, dy.tanh(src_output)])

    prev_word = sos_trg
    trg_sent = []
    W_sm = dy.parameter(W_sm_p)
    b_sm = dy.parameter(b_sm_p)

    for _ in range(MAX_SENT_SIZE):
        #feed the previous word into the lstm, calculate the most likely word, add it to the sentence
        current_state = current_state.add_input(LOOKUP_TRG[prev_word])
        output_embedding = current_state.output()
        s = dy.affine_transform([b_sm, W_sm, output_embedding])
        # The most likely word has the HIGHEST log-probability, so take the
        # argmax of the log-softmax itself (not of its negation).
        log_probs = dy.log_softmax(s).npvalue()
        # plain int so the id can be used for parameter lookup and dict access
        next_word = int(np.argmax(log_probs))

        if next_word == eos_trg:
            break
        prev_word = next_word
        trg_sent.append(i2w_trg[next_word])
    return trg_sent



In [0]:
# NOTE(review): this loop expects a DyNet-style trainer (`trainer.update()`) and
# a `create_batches` helper. The `trainer` created earlier in this notebook is a
# torch.optim.Adam (stepped with `.step()`), and `create_batches` is not defined
# in the visible cells -- confirm before running.
for ITER in range(100):
  # Perform training
  train.sort(key=lambda t: len(t[0]), reverse=True)
  dev.sort(key=lambda t: len(t[0]), reverse=True)
  train_order = create_batches(train, BATCH_SIZE) 
  dev_order = create_batches(dev, BATCH_SIZE)
  train_words, train_loss = 0, 0.0
  # The timer has its own name so the batch offsets below cannot overwrite it.
  start_time = time.time()
  for sent_id, (batch_start, length) in enumerate(train_order):
    train_batch = train[batch_start:batch_start+length]
    my_loss, num_words = calc_loss(train_batch)
    train_loss += my_loss.value()
    train_words += num_words
    my_loss.backward()
    trainer.update()
    if (sent_id+1) % 5000 == 0:
      print("--finished %r sentences" % (sent_id+1))
  print("iter %r: train loss/word=%.4f, ppl=%.4f, time=%.2fs" % (ITER, train_loss/train_words, math.exp(train_loss/train_words), time.time()-start_time))
  # Evaluate on dev set
  dev_words, dev_loss = 0, 0.0
  start_time = time.time()
  for sent_id, (batch_start, length) in enumerate(dev_order):
    dev_batch = dev[batch_start:batch_start+length]
    my_loss, num_words = calc_loss(dev_batch)
    dev_loss += my_loss.value()
    dev_words += num_words
    # Evaluation only: the parameters must not be updated on dev data.
  print("iter %r: dev loss/word=%.4f, ppl=%.4f, time=%.2fs" % (ITER, dev_loss/dev_words, math.exp(dev_loss/dev_words), time.time()-start_time))