In [1]:
!pip install torch torchtext
!git clone https://github.com/neubig/nn4nlp-code.git

fatal: destination path 'nn4nlp-code' already exists and is not an empty directory.


In [2]:
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader

In [3]:
# Sentence-aligned Japanese (.ja) -> English (.en) corpus from the nn4nlp-code
# repository. Change DATA_DIR to use a different copy of the data.
DATA_DIR = "nn4nlp-code/data/parallel"

train_src_file = DATA_DIR + "/train.ja"
train_trg_file = DATA_DIR + "/train.en"
dev_src_file = DATA_DIR + "/dev.ja"
dev_trg_file = DATA_DIR + "/dev.en"
test_src_file = DATA_DIR + "/test.ja"
test_trg_file = DATA_DIR + "/test.en"

In [4]:
# Growing vocabularies: looking up an unseen token assigns it the next free id.
w2i_src = defaultdict(lambda: len(w2i_src))
w2i_trg = defaultdict(lambda: len(w2i_trg))

def read(fname_src, fname_trg, src_vocab=None, trg_vocab=None):
    """
    Read a sentence-aligned parallel corpus and map tokens to ids.

    Line i of ``fname_src`` is the translation of line i of ``fname_trg``.
    Lines are whitespace-tokenized. Each source sentence gets a trailing
    ``</s>``; each target sentence is wrapped as ``<s> ... </s>`` so the
    decoder has a start symbol to read and an end symbol to predict.

    Args:
        fname_src: path to the source-language file (UTF-8).
        fname_trg: path to the target-language file (UTF-8).
        src_vocab: token -> id mapping for the source side. Defaults to the
            module-level ``w2i_src`` as bound when iteration starts. With a
            defaultdict, lookups of unseen tokens insert them.
        trg_vocab: same for the target side; defaults to ``w2i_trg``.

    Yields:
        ``(src_ids, trg_ids)`` tuples of int lists, one per line pair.

    Raises:
        ValueError: if the two files have a different number of lines
            (raised when iteration reaches the first unmatched line).
    """
    import itertools

    if src_vocab is None:
        src_vocab = w2i_src
    if trg_vocab is None:
        trg_vocab = w2i_trg
    # Explicit UTF-8: the source side is Japanese, so don't rely on the locale.
    with open(fname_src, "r", encoding="utf-8") as f_src, \
            open(fname_trg, "r", encoding="utf-8") as f_trg:
        pairs = itertools.zip_longest(f_src, f_trg)
        for line_no, (line_src, line_trg) in enumerate(pairs, start=1):
            # zip() would silently drop the tail of the longer file and
            # misalign nothing visibly; fail loudly instead.
            if line_src is None or line_trg is None:
                raise ValueError(
                    "parallel files {!r} and {!r} have different line counts "
                    "(first unmatched line: {})".format(fname_src, fname_trg, line_no))
            # EOS is appended to both sides; BOS only to the target side.
            sent_src = [src_vocab[x] for x in line_src.strip().split() + ['</s>']]
            sent_trg = [trg_vocab[x] for x in ['<s>'] + line_trg.strip().split() + ['</s>']]
            yield (sent_src, sent_trg)

# Read the training split first: every token seen here receives an id in the
# still-growing vocabularies.
train = list(read(train_src_file, train_trg_file))

# Freeze the source vocabulary. "<unk>" is registered after the training data,
# so it takes the next free id; from here on, unseen source tokens map to it.
unk_src = w2i_src["<unk>"]
eos_src = w2i_src['</s>']
w2i_src = defaultdict(lambda: unk_src, w2i_src)

# Same for the target vocabulary, also remembering the sentence-boundary ids.
unk_trg = w2i_trg["<unk>"]
eos_trg = w2i_trg['</s>']
sos_trg = w2i_trg['<s>']
w2i_trg = defaultdict(lambda: unk_trg, w2i_trg)

# Reverse maps and vocabulary sizes must be taken *before* reading dev/test:
# the frozen defaultdicts still insert unseen tokens (mapped to the <unk> id)
# on lookup, so their len() keeps growing afterwards.
i2w_trg = dict((idx, tok) for tok, idx in w2i_trg.items())
i2w_src = dict((idx, tok) for tok, idx in w2i_src.items())
nwords_src, nwords_trg = len(w2i_src), len(w2i_trg)

dev = list(read(dev_src_file, dev_trg_file))
test = list(read(test_src_file, test_trg_file))

In [5]:
class ParallelCorpus(Dataset):
    """Map-style dataset over pre-tokenized ``(src_ids, trg_ids)`` pairs.

    ``data`` is an indexable sequence of pairs of int lists (e.g. the output
    of ``read``). Each item comes back as a pair of int64 tensors.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, ix):
        pair = self.data[ix]
        src_ids, trg_ids = pair[0], pair[1]
        return torch.LongTensor(src_ids), torch.LongTensor(trg_ids)
  
def my_collate_fn(batch):
    """Collate ``(src, trg)`` 1-D tensor pairs into right-padded batches.

    Returns ``(src, trg, src_len, trg_len)``: ``src``/``trg`` are
    ``(batch, max_len)`` tensors padded on the right with 0, and the two
    length tensors hold each sequence's unpadded length.

    NOTE(review): 0 is not a reserved pad id. With the vocabularies built in
    this notebook it is a real token (on the target side it is ``<s>``), so
    downstream code must use the lengths, not the pad value, to find padding.
    """
    sources, targets = zip(*batch)
    src_len = torch.LongTensor([len(seq) for seq in sources])
    trg_len = torch.LongTensor([len(seq) for seq in targets])
    src = nn.utils.rnn.pad_sequence(list(sources), batch_first=True, padding_value=0)
    trg = nn.utils.rnn.pad_sequence(list(targets), batch_first=True, padding_value=0)
    return src, trg, src_len, trg_len

# my_collate_fn([train_corpus[i] for i in range(4)])

In [6]:
class EncoderRNN(nn.Module):
    """Embed source token ids and run a single-layer, unidirectional GRU.

    Args:
        input_size: source vocabulary size.
        embed_size: embedding dimension.
        hidden_size: GRU hidden dimension (also the size of each output).
    """

    def __init__(self, input_size, embed_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, embed_size)
        self.gru = nn.GRU(embed_size, hidden_size, batch_first=True)

    def forward(self, x, x_len):
        """Encode a batch of source sentences.

        Args:
            x: (batch, src_len) LongTensor of token ids.
            x_len: per-sentence lengths; only its first dimension (the batch
                size) is used. Padded steps are NOT skipped, so outputs at
                padded positions are GRU states over pad ids.

        Returns:
            (batch, src_len, hidden_size) GRU outputs for every time step.
        """
        h0 = self.init_hidden(x_len.shape[0])
        encoded = self.embedding(x)
        output, _ = self.gru(encoded, h0)
        return output

    def init_hidden(self, bs):
        """Zero initial hidden state of shape (1, bs, hidden_size).

        Created on the device/dtype of the module's own parameters, so the
        encoder works wherever it was moved with ``.to(...)`` rather than
        depending on a global ``device`` defined in a later cell.
        """
        weight = self.embedding.weight
        return torch.zeros(1, bs, self.hidden_size, device=weight.device, dtype=weight.dtype)

In [7]:
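# NOTE: earlier draft kept for reference only; it is superseded by the
# `Attention` class in the next cell and would not run as written
# (`self.hidden_size` / `hidden_size` are undefined in __init__, and
# `mask_fill` should be `masked_fill`).
# class Attention(nn.Module):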
#   def __init__(self, method, hidden_dim):
#     super(Attention, self).__init__()
    
#     self.method = method
#     self.hidden_dim = hidden_dim
    
#     if self.method == 'general':
#       self.attn = nn.Linear(self.hidden_size, hidden_size)
#     elif self.method == 'concat':
#       self.attn = nn.Linear(hidden_dim * 2, hidden_dim)
#       self.v = nn.Parameter(torch.FloatTensor(1, hidden_dim))
    
#   def forward(self, query, value, mask=None):
#     # query: bs x maxlen x hidden_dim
#     score = self.score(query, value)
#     if mask is not None:
#       score = score.mask_fill(mask == 0, -1e9)
#     p_attn = F.softmax(score, dim=-1)
    
#     return torch.matmul(p_attn, value), p_attn
  
#   def score(self, query, value):
#     if self.method == 'dot':
#       return torch.matmul(query, value.transpose(-2, -1))
#     elif self.method == 'general':
#       return torch.matmul(query, self.attn(value).transpose(-2, -1))
#     elif self.method == 'concat':
#       return torch.matmul(self.v, F.tanh(self.attn(torch.cat((query, value), 1)).transpose(-2, -1)))

In [8]:
class Attention(nn.Module):
    """Dot-product attention followed by a tanh output projection.

    Each context vector is scored by its dot product with the query. The scores
    are softmaxed over the sequence and used to take a weighted sum of the
    context. The result is ``tanh(W [query; weighted_context])``, where W is a
    bias-free linear layer.

    Args:
        dim: size of the query and of each context vector.
    """

    def __init__(self, dim):
        super(Attention, self).__init__()
        self.linear = nn.Linear(dim*2, dim, bias=False)

    def forward(self, x, context, mask=None):
        """Attend over ``context`` with query ``x``.

        Args:
            x: (batch, dim) query vectors.
            context: (batch, seq, dim) vectors to attend over.
            mask: optional (batch, seq) tensor, truthy where a position may be
                attended to (e.g. non-padding source tokens). Masked positions
                get zero weight. ``None`` (default) attends to every position.

        Returns:
            (batch, dim) attentional hidden state.
        """
        scores = context.bmm(x.unsqueeze(2)).squeeze(2)        # batch x seq
        if mask is not None:
            # Use the dtype's finite minimum rather than -inf: masked entries
            # still get zero weight, but a fully-masked row yields uniform
            # weights instead of NaN.
            scores = scores.masked_fill(~mask.bool(), torch.finfo(scores.dtype).min)
        attn = F.softmax(scores, dim=1)                        # batch x seq
        weighted_context = attn.unsqueeze(1).bmm(context)      # batch x 1 x dim
        weighted_context = weighted_context.squeeze(1)         # batch x dim
        o = self.linear(torch.cat((x, weighted_context), 1))
        # torch.tanh is the canonical op (F.tanh warns as deprecated on some
        # PyTorch releases); the result is identical.
        return torch.tanh(o)

In [15]:
class BahdanauAttnDecoderRNN(nn.Module):
    """Teacher-forced GRU decoder with attention over the encoder outputs.

    NOTE(review): despite the name, the attention used here (``Attention``)
    scores with dot products and combines with a concat/tanh layer, i.e.
    Luong-style multiplicative attention rather than Bahdanau's additive form.

    Args:
        output_size: target vocabulary size.
        embed_size: embedding dimension.
        hidden_size: GRU hidden size; must equal the size of each encoder
            output, because attention scores are dot products with ``context``.
    """

    def __init__(self, output_size, embed_size, hidden_size):
        super(BahdanauAttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, embed_size)
        self.gru = nn.GRU(embed_size, hidden_size, batch_first=True)
        self.attn = Attention(hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def init_hidden(self, bs):
        """Zero initial hidden state (1, bs, hidden_size).

        Created on the device/dtype of the module's parameters instead of a
        global ``device`` defined in a later cell.
        """
        weight = self.embedding.weight
        return torch.zeros(1, bs, self.hidden_size, device=weight.device, dtype=weight.dtype)

    def forward(self, xs, xs_len, context):
        """Score every position of a teacher-forced target batch.

        Args:
            xs: (batch, tgt_len) LongTensor of decoder input ids.
            xs_len: unused; kept for interface compatibility.
            context: (batch, src_len, hidden_size) encoder outputs. Padded
                source positions are not masked here.

        Returns:
            (batch, tgt_len, output_size) unnormalised scores over the target
            vocabulary.
        """
        bs = xs.size(0)
        # Each GRU step reads only the gold token embedding (the attentional
        # output is not fed back), so a single GRU call over the whole sequence
        # computes the same states as stepping it one token at a time.
        states, _ = self.gru(self.embedding(xs), self.init_hidden(bs))  # batch x tgt_len x hidden
        attended = torch.stack(
            [self.attn(states[:, step], context) for step in range(states.size(1))],
            dim=1,
        )                                                                # batch x tgt_len x hidden
        return self.out(attended)

In [16]:
# Seed torch's global RNG before the DataLoaders and models below are created,
# so weight initialisation and the training shuffle order repeat across runs
# (on GPU, cuDNN kernels may still be non-deterministic).
SEED = 0
torch.manual_seed(SEED)

# Model parameters
EMBED_SIZE = 64
HIDDEN_SIZE = 128
BATCH_SIZE = 16

In [17]:
# Items are built from in-memory lists, so worker processes add startup cost
# without speeding anything up. num_workers=0 also avoids pickling
# notebook-defined objects (ParallelCorpus, my_collate_fn live in __main__)
# into spawned workers, which can fail on spawn-based platforms (macOS, Windows).
# Batch contents and order are unaffected by this setting.
NUM_WORKERS = 0

train_corpus = ParallelCorpus(train)
train_loader = DataLoader(train_corpus, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, collate_fn=my_collate_fn)

dev_corpus = ParallelCorpus(dev)
dev_loader = DataLoader(dev_corpus, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, collate_fn=my_collate_fn)

In [18]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Encoder first, then decoder (construction order fixes the RNG draws for init).
encoder = EncoderRNN(input_size=nwords_src, embed_size=EMBED_SIZE, hidden_size=HIDDEN_SIZE)
encoder = encoder.to(device)
decoder = BahdanauAttnDecoderRNN(output_size=nwords_trg, embed_size=EMBED_SIZE, hidden_size=HIDDEN_SIZE)
decoder = decoder.to(device)

In [19]:
# Token-level cross-entropy on the decoder's unnormalised scores, and a single
# Adam optimiser over the parameters of both encoder and decoder.
criterion = nn.CrossEntropyLoss()
trainer = torch.optim.Adam([*encoder.parameters(), *decoder.parameters()], lr=0.001)

In [22]:
def _batch_loss(batch):
    """Teacher-forced cross-entropy for one batch from ``my_collate_fn``.

    ``batch`` is ``(src, trg, src_len, trg_len)``. Each target row is
    ``<s> w1 .. wn </s>`` followed by padding: the decoder reads ``trg[:, :-1]``
    and is scored against ``trg[:, 1:]``. Only the first ``trg_len - 1`` steps
    of each row are real, so padded steps are excluded from the loss.
    """
    s, t, sl, tl = (x.to(device) for x in batch)
    t_in, t_out = t[:, :-1], t[:, 1:]
    decoded = decoder(t_in, tl, encoder(s, sl))          # batch x (T-1) x V
    # Boolean-mask selection keeps the same positions, in the same row-major
    # order, as slicing each row to tl-1 and concatenating the slices.
    steps = torch.arange(t_out.size(1), device=t_out.device)
    valid = steps.unsqueeze(0) < (tl - 1).unsqueeze(1)   # batch x (T-1)
    return criterion(decoded[valid], t_out[valid])


N_EPOCHS = 20

for epoch_i in range(N_EPOCHS):
    encoder.train()
    decoder.train()
    total_loss = 0.
    for batch in train_loader:
        loss = _batch_loss(batch)
        total_loss += loss.item()

        trainer.zero_grad()
        loss.backward()
        trainer.step()

    print("epoch {} | train loss {:5.4f}".format(epoch_i, total_loss / len(train_loader)))

    encoder.eval()
    decoder.eval()
    dev_loss = 0.
    # Validation needs no gradients; don't build (and hold) autograd graphs.
    with torch.no_grad():
        for batch in dev_loader:
            dev_loss += _batch_loss(batch).item()
    print("epoch {} | val loss {:5.4f}".format(epoch_i, dev_loss / len(dev_loader)))

epoch 0 | train loss 4.1152
epoch 0 | val loss 4.4178
epoch 1 | train loss 3.8242
epoch 1 | val loss 4.3185
epoch 2 | train loss 3.5713
epoch 2 | val loss 4.2843
epoch 3 | train loss 3.3382
epoch 3 | val loss 4.3012
epoch 4 | train loss 3.1197
epoch 4 | val loss 4.3102
epoch 5 | train loss 2.9122
epoch 5 | val loss 4.3525
epoch 6 | train loss 2.7126
epoch 6 | val loss 4.3840
epoch 7 | train loss 2.5219
epoch 7 | val loss 4.4362
epoch 8 | train loss 2.3444
epoch 8 | val loss 4.4829
epoch 9 | train loss 2.1730
epoch 9 | val loss 4.5617
epoch 10 | train loss 2.0136
epoch 10 | val loss 4.6452
epoch 11 | train loss 1.8651
epoch 11 | val loss 4.7111
epoch 12 | train loss 1.7273
epoch 12 | val loss 4.7986
epoch 13 | train loss 1.5958
epoch 13 | val loss 4.8795
epoch 14 | train loss 1.4759
epoch 14 | val loss 4.9750
epoch 15 | train loss 1.3665
epoch 15 | val loss 5.0246
epoch 16 | train loss 1.2590
epoch 16 | val loss 5.1462
epoch 17 | train loss 1.1625
epoch 17 | val loss 5.2381
epoch 18 | t