In [19]:
!pip install torch
!git clone https://github.com/neubig/nn4nlp-code.git

fatal: destination path 'nn4nlp-code' already exists and is not an empty directory.


In [0]:
from __future__ import print_function
import time

start = time.time()

from collections import Counter, defaultdict
import random
import math
import sys
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import itertools

In [0]:
# format of files: each line is "word1 word2 ..."
train_file = "nn4nlp-code/data/ptb/train.txt"
test_file = "nn4nlp-code/data/ptb/valid.txt"

w2i = defaultdict(lambda: len(w2i))


def read(fname):
    """
    Read a file where each line is of the form "word1 word2 ..."
    Yields lists of the form [word1, word2, ...]
    """
    with open(fname, "r") as fh:
        for line in fh:
            sent = [w2i[x] for x in line.strip().split()]
            sent.append(w2i["<s>"])
            yield sent


train = list(read(train_file))
nwords = len(w2i)
test = list(read(test_file))
S = w2i["<s>"]
assert (nwords == len(w2i))

In [0]:
def batchify(data, bsz):
  nbatch = data.size(0) // bsz
  data = data.narrow(0, 0, nbatch * bsz)
  data = data.view(bsz, -1).t().contiguous()
  return data.to(device)

In [23]:
batch_size = 32
bptt_size = 64

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_flatten = torch.LongTensor(list(itertools.chain.from_iterable(train)))
test_flatten = torch.LongTensor(list(itertools.chain.from_iterable(test)))

train_data = batchify(train_flatten, batch_size)
test_data = batchify(test_flatten, batch_size)

train_data.shape, test_data.shape

(torch.Size([29049, 32]), torch.Size([2305, 32]))

In [0]:
class LanguageModel(nn.Module):
  def __init__(self, nwords, emb_size, hidden_size):
    super(LanguageModel, self).__init__()
    self.nwords = nwords
    self.emb_size = emb_size
    self.hidden_size = hidden_size
    
    self.encoder = nn.Embedding(nwords, emb_size)
    self.rnn = nn.LSTM(emb_size, hidden_size)
    self.decoder = nn.Linear(hidden_size, nwords)
    
  def forward(self, x, hidden):
    emb = self.encoder(x)
    output, hidden = self.rnn(emb, hidden)
    decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
    return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

In [0]:
EMBED_SIZE = 64
HIDDEN_SIZE = 128
log_interval = 200

In [0]:
model = LanguageModel(nwords, EMBED_SIZE, HIDDEN_SIZE).to(device)
trainer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [0]:
def repackage_hidden(h):
  if isinstance(h, torch.Tensor):
    return h.detach()
  else:
    return tuple(repackage_hidden(v) for v in h)

def get_batch(source, i):
  seq_len = min(bptt_size, len(source) - 1 - i)
  data = source[i:i+seq_len]
  target = source[i+1:i+1+seq_len].view(-1)
  return data, target

In [0]:
def train():
  model.train()
  total_loss = 0.
  start_time = time.time()
  hidden = (torch.zeros(1, batch_size, HIDDEN_SIZE).to(device), torch.zeros(1, batch_size, HIDDEN_SIZE).to(device))

  for batch_ix, i in enumerate(range(0, train_data.size(0) - 1, bptt_size)):
    data, targets = get_batch(train_data, i)
    hidden = repackage_hidden(hidden)
    model.zero_grad()
    output, hidden = model(data, hidden)
    loss = criterion(output.view(-1, nwords), targets)
    loss.backward()

    trainer.step()

    total_loss += loss.item()

    if batch_ix % log_interval == 0 and batch_ix > 0:
      cur_loss = total_loss / log_interval
      elapsed = time.time() - start_time
      print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
              'loss {:5.2f} | ppl {:8.2f}'.format(
          epoch, batch_ix, len(train_data) // bptt_size, 
          elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)))
      total_loss = 0
      start_time = time.time()

In [0]:
def evaluate(data_source):
  model.eval()
  total_loss = 0.
  hidden = (torch.zeros(1, batch_size, HIDDEN_SIZE).to(device), torch.zeros(1, batch_size, HIDDEN_SIZE).to(device))
  with torch.no_grad():
    for i in range(0, data_source.size(0) - 1, bptt_size):
      data, targets = get_batch(data_source, i)
      output, hidden = model(data, hidden)
      output_flat = output.view(-1, nwords)
      total_loss += len(data) * criterion(output_flat, targets).item()
      hidden = repackage_hidden(hidden)
  
  return total_loss / len(data_source)

In [34]:
for epoch in range(20):
  epoch_start_time = time.time()
  train()
  val_loss = evaluate(test_data)
  print('-' * 89)
  print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
  print('-' * 89)

| epoch   0 |   200/  453 batches | ms/batch 29.06 | loss  5.50 | ppl   244.28
| epoch   0 |   400/  453 batches | ms/batch 28.16 | loss  5.35 | ppl   211.55
-----------------------------------------------------------------------------------------
| end of epoch   0 | time: 13.32s | valid loss  5.43 | valid ppl   228.00
-----------------------------------------------------------------------------------------
| epoch   1 |   200/  453 batches | ms/batch 28.26 | loss  5.36 | ppl   211.96
| epoch   1 |   400/  453 batches | ms/batch 28.19 | loss  5.23 | ppl   185.95
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 13.18s | valid loss  5.35 | valid ppl   210.28
-----------------------------------------------------------------------------------------
| epoch   2 |   200/  453 batches | ms/batch 28.30 | loss  5.24 | ppl   189.32
| epoch   2 |   400/  453 batches | ms/batch 28.11 | loss  5.12 | ppl   167.38
------------------

-----------------------------------------------------------------------------------------
| end of epoch  12 | time: 13.15s | valid loss  5.06 | valid ppl   158.19
-----------------------------------------------------------------------------------------
| epoch  13 |   200/  453 batches | ms/batch 28.30 | loss  4.60 | ppl    99.74
| epoch  13 |   400/  453 batches | ms/batch 28.21 | loss  4.50 | ppl    90.39
-----------------------------------------------------------------------------------------
| end of epoch  13 | time: 13.18s | valid loss  5.06 | valid ppl   157.59
-----------------------------------------------------------------------------------------
| epoch  14 |   200/  453 batches | ms/batch 28.25 | loss  4.57 | ppl    96.26
| epoch  14 |   400/  453 batches | ms/batch 28.09 | loss  4.47 | ppl    87.32
-----------------------------------------------------------------------------------------
| end of epoch  14 | time: 13.14s | valid loss  5.06 | valid ppl   157.23
------------

| epoch  25 |   200/  453 batches | ms/batch 28.29 | loss  4.27 | ppl    71.39
| epoch  25 |   400/  453 batches | ms/batch 28.04 | loss  4.18 | ppl    65.10
-----------------------------------------------------------------------------------------
| end of epoch  25 | time: 13.14s | valid loss  5.11 | valid ppl   164.98
-----------------------------------------------------------------------------------------
| epoch  26 |   200/  453 batches | ms/batch 28.18 | loss  4.25 | ppl    69.87
| epoch  26 |   400/  453 batches | ms/batch 28.04 | loss  4.15 | ppl    63.73
-----------------------------------------------------------------------------------------
| end of epoch  26 | time: 13.11s | valid loss  5.11 | valid ppl   166.45
-----------------------------------------------------------------------------------------
| epoch  27 |   200/  453 batches | ms/batch 28.21 | loss  4.23 | ppl    68.42
| epoch  27 |   400/  453 batches | ms/batch 28.09 | loss  4.13 | ppl    62.45
------------------

-----------------------------------------------------------------------------------------
| end of epoch  37 | time: 13.01s | valid loss  5.23 | valid ppl   186.27
-----------------------------------------------------------------------------------------
| epoch  38 |   200/  453 batches | ms/batch 28.02 | loss  4.03 | ppl    56.45
| epoch  38 |   400/  453 batches | ms/batch 27.88 | loss  3.95 | ppl    51.70
-----------------------------------------------------------------------------------------
| end of epoch  38 | time: 13.03s | valid loss  5.24 | valid ppl   188.37
-----------------------------------------------------------------------------------------
| epoch  39 |   200/  453 batches | ms/batch 28.08 | loss  4.02 | ppl    55.60
| epoch  39 |   400/  453 batches | ms/batch 27.94 | loss  3.93 | ppl    50.94
-----------------------------------------------------------------------------------------
| end of epoch  39 | time: 13.06s | valid loss  5.25 | valid ppl   190.57
------------

| epoch  50 |   200/  453 batches | ms/batch 28.19 | loss  3.88 | ppl    48.34
| epoch  50 |   400/  453 batches | ms/batch 28.01 | loss  3.79 | ppl    44.32
-----------------------------------------------------------------------------------------
| end of epoch  50 | time: 13.10s | valid loss  5.38 | valid ppl   217.96
-----------------------------------------------------------------------------------------
| epoch  51 |   200/  453 batches | ms/batch 28.06 | loss  3.87 | ppl    47.79
| epoch  51 |   400/  453 batches | ms/batch 28.09 | loss  3.78 | ppl    43.82
-----------------------------------------------------------------------------------------
| end of epoch  51 | time: 13.09s | valid loss  5.39 | valid ppl   220.29
-----------------------------------------------------------------------------------------
| epoch  52 |   200/  453 batches | ms/batch 28.09 | loss  3.86 | ppl    47.27
| epoch  52 |   400/  453 batches | ms/batch 27.94 | loss  3.77 | ppl    43.33
------------------

-----------------------------------------------------------------------------------------
| end of epoch  62 | time: 13.14s | valid loss  5.53 | valid ppl   252.69
-----------------------------------------------------------------------------------------
| epoch  63 |   200/  453 batches | ms/batch 28.25 | loss  3.74 | ppl    42.15
| epoch  63 |   400/  453 batches | ms/batch 28.10 | loss  3.66 | ppl    38.95
-----------------------------------------------------------------------------------------
| end of epoch  63 | time: 13.14s | valid loss  5.54 | valid ppl   255.76
-----------------------------------------------------------------------------------------
| epoch  64 |   200/  453 batches | ms/batch 28.25 | loss  3.73 | ppl    41.78
| epoch  64 |   400/  453 batches | ms/batch 28.12 | loss  3.65 | ppl    38.61
-----------------------------------------------------------------------------------------
| end of epoch  64 | time: 13.15s | valid loss  5.56 | valid ppl   258.92
------------

| epoch  75 |   200/  453 batches | ms/batch 28.25 | loss  3.64 | ppl    38.15
| epoch  75 |   400/  453 batches | ms/batch 28.15 | loss  3.57 | ppl    35.48
-----------------------------------------------------------------------------------------
| end of epoch  75 | time: 13.16s | valid loss  5.68 | valid ppl   291.81
-----------------------------------------------------------------------------------------
| epoch  76 |   200/  453 batches | ms/batch 28.28 | loss  3.63 | ppl    37.87
| epoch  76 |   400/  453 batches | ms/batch 28.19 | loss  3.56 | ppl    35.21
-----------------------------------------------------------------------------------------
| end of epoch  76 | time: 13.18s | valid loss  5.69 | valid ppl   294.73
-----------------------------------------------------------------------------------------
| epoch  77 |   200/  453 batches | ms/batch 28.30 | loss  3.63 | ppl    37.57
| epoch  77 |   400/  453 batches | ms/batch 28.23 | loss  3.55 | ppl    34.94
------------------

-----------------------------------------------------------------------------------------
| end of epoch  87 | time: 13.33s | valid loss  5.79 | valid ppl   328.40
-----------------------------------------------------------------------------------------
| epoch  88 |   200/  453 batches | ms/batch 28.74 | loss  3.55 | ppl    34.98
| epoch  88 |   400/  453 batches | ms/batch 28.65 | loss  3.48 | ppl    32.53
-----------------------------------------------------------------------------------------
| end of epoch  88 | time: 13.39s | valid loss  5.80 | valid ppl   331.42
-----------------------------------------------------------------------------------------
| epoch  89 |   200/  453 batches | ms/batch 28.70 | loss  3.55 | ppl    34.80
| epoch  89 |   400/  453 batches | ms/batch 28.70 | loss  3.48 | ppl    32.33
-----------------------------------------------------------------------------------------
| end of epoch  89 | time: 13.39s | valid loss  5.81 | valid ppl   334.31
------------