
ONMT fixes and updates #82


Merged: 44 commits, merged Mar 14, 2017

Commits (44)
6832a7d
translate bug fix
bmccann Feb 24, 2017
fe15b4a
README changes for multi-gpu
bmccann Feb 24, 2017
316a524
removing reinit of checkpoint params again
bmccann Feb 24, 2017
99dec5c
using split instead of chunk
bmccann Feb 24, 2017
2078b14
replacing opt.cuda with opt.gpus as needed
bmccann Feb 24, 2017
d685c1c
using ModuleList
bmccann Feb 24, 2017
4d0c84f
default type for start_decay_at
bmccann Feb 24, 2017
f059047
decoder hidden state fix
bmccann Feb 28, 2017
af94796
nn.clip_grad_norm
bmccann Mar 1, 2017
a2caf64
adding src/tgt tokens/s
bmccann Mar 1, 2017
727863b
index in verbose translate was fixed
bmccann Mar 1, 2017
639eb45
bug in total num predicted words
bmccann Mar 1, 2017
bb9d462
Variables in Translator can be volatile
bmccann Mar 1, 2017
2681c91
removing unnecessary def
bmccann Mar 2, 2017
6478147
allowing lowercase option
bmccann Mar 2, 2017
70c3d8f
pointing out one way to do bleu scores in README
bmccann Mar 2, 2017
1c7d2ea
adding files to ignore
bmccann Mar 2, 2017
f45c628
preprocess needs to use lower option
bmccann Mar 2, 2017
36793a0
tips for non-demo mt via flickr30k example
bmccann Mar 2, 2017
a5349bf
cleaning up readme
bmccann Mar 2, 2017
7f518d2
clean up the readme
bmccann Mar 2, 2017
9af532c
spacing in readme
bmccann Mar 2, 2017
e48f620
cudnn decoder
bmccann Mar 2, 2017
d5cfec3
reverting cudnn decoder to lstmcell
bmccann Mar 2, 2017
a8d66b4
new DataParallel allows dim 1; remove unnecessary transposes; add tra…
bmccann Mar 3, 2017
3d91103
mend
bmccann Mar 3, 2017
e4a6730
allows use of models trained on dataset to be trained on another; doe…
bmccann Mar 3, 2017
a2d8bf7
manual unrolling was broken for brnn; patch until varlen rnn replacement
bmccann Mar 3, 2017
6dcb113
allowing learning rate update for non-sgd optimizers
bmccann Mar 6, 2017
1226bde
adding option to shuffle mini-batches
bmccann Mar 6, 2017
8f543a8
adding word level accuracy as a metric
bmccann Mar 6, 2017
4678ecd
touch ups and README updates
bmccann Mar 7, 2017
859412c
allowing validation data to be volatile
bmccann Mar 9, 2017
053aadf
num_batches was off by one
bmccann Mar 9, 2017
45e13b5
batch printing was off
bmccann Mar 9, 2017
e10644c
curriculum off by one
bmccann Mar 9, 2017
c4ae24c
accuracy now an average over log_interval batches
bmccann Mar 9, 2017
0e728be
off by one in printing batch number
bmccann Mar 9, 2017
065af29
removing unused variables
bmccann Mar 9, 2017
22e4fb1
saving with state_dict
bmccann Mar 10, 2017
2c02971
state_dicts for translation and optimizer
bmccann Mar 10, 2017
6c8b710
Grouping bash commands together
bmccann Mar 10, 2017
c359f4f
backwards compatibility for checkpoints
bmccann Mar 14, 2017
8cebfba
one more lowercase in dict
bmccann Mar 14, 2017
3 changes: 3 additions & 0 deletions OpenNMT/.gitignore
@@ -0,0 +1,3 @@
pred.txt
multi-bleu.perl
*.pt

60 changes: 53 additions & 7 deletions OpenNMT/README.md
@@ -8,23 +8,69 @@ an open-source (MIT) neural machine translation system.

## Quickstart

OpenNMT consists of three commands:
Use of OpenNMT consists of four steps:

0) Download the data.
### 0) Download the data.

```wget https://s3.amazonaws.com/pytorch/examples/opennmt/data/onmt-data.tar && tar -xf onmt-data.tar```

1) Preprocess the data.
### 1) Preprocess the data.

```python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/demo```

2) Train the model.
### 2) Train the model.

```python train.py -data data/demo-train.pt -save_model model -gpus 1```
```python train.py -data data/demo-train.pt -save_model demo_model -gpus 0```

3) Translate sentences.
### 3) Translate sentences.

```python translate.py -gpu 1 -model model_e13_*.pt -src data/src-test.txt -tgt data/tgt-test.txt -replace_unk -verbose```
```python translate.py -gpu 0 -model demo_model_e13_*.pt -src data/src-test.txt -tgt data/tgt-test.txt -replace_unk -verbose -output demo_pred.txt```

### 4) Evaluate.

```bash
wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/generic/multi-bleu.perl
perl multi-bleu.perl data/tgt-test.txt < demo_pred.txt
```

## WMT'16 Multimodal Translation: Multi30k (de-en)

Real data rarely comes as clean as the demo data. Here is a second example that uses the Moses tokenizer (http://www.statmt.org/moses/) to prepare the Multi30k data from the WMT'16 Multimodal Translation task (http://www.statmt.org/wmt16/multimodal-task.html).

### 0) Download the data.

```bash
mkdir -p data/multi30k
wget http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz && tar -xf training.tar.gz -C data/multi30k && rm training.tar.gz
wget http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz && tar -xf validation.tar.gz -C data/multi30k && rm validation.tar.gz
wget https://staff.fnwi.uva.nl/d.elliott/wmt16/mmt16_task1_test.tgz && tar -xf mmt16_task1_test.tgz -C data/multi30k && rm mmt16_task1_test.tgz
```

### 1) Preprocess the data.

```bash
wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/tokenizer/tokenizer.perl
sed -i "s/$RealBin\/..\/share\/nonbreaking_prefixes//" tokenizer.perl
wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.de
wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/share/nonbreaking_prefixes/nonbreaking_prefix.en
for l in en de; do for f in data/multi30k/*.$l; do if [[ "$f" != *"test"* ]]; then sed -i "$ d" $f; fi; perl tokenizer.perl -no-escape -l $l -q < $f > $f.tok; done; done
python preprocess.py -train_src data/multi30k/train.en.tok -train_tgt data/multi30k/train.de.tok -valid_src data/multi30k/val.en.tok -valid_tgt data/multi30k/val.de.tok -save_data data/multi30k
```

### 2) Train the model.

```python train.py -data data/multi30k-train.pt -save_model multi30k_model -gpus 0```

### 3) Translate sentences.

```python translate.py -gpu 0 -model multi30k_model_e13_*.pt -src data/multi30k/test.en.tok -tgt data/multi30k/test.de.tok -replace_unk -verbose -output multi30k_pred.txt```

### 4) Evaluate.

```bash
wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/master/scripts/generic/multi-bleu.perl
perl multi-bleu.perl data/multi30k/test.de.tok < multi30k_pred.txt
```

## Pretrained Models

16 changes: 13 additions & 3 deletions OpenNMT/onmt/Dataset.py
@@ -1,10 +1,13 @@
import math
import random

import onmt
from torch.autograd import Variable


class Dataset(object):

def __init__(self, srcData, tgtData, batchSize, cuda):
def __init__(self, srcData, tgtData, batchSize, cuda, volatile=False):
self.src = srcData
if tgtData:
self.tgt = tgtData
@@ -14,7 +17,8 @@ def __init__(self, srcData, tgtData, batchSize, cuda):
self.cuda = cuda

self.batchSize = batchSize
self.numBatches = (len(self.src) + batchSize - 1) // batchSize
self.numBatches = math.ceil(len(self.src)/batchSize)
self.volatile = volatile

def _batchify(self, data, align_right=False):
max_length = max(x.size(0) for x in data)
@@ -28,7 +32,7 @@ def _batchify(self, data, align_right=False):
if self.cuda:
out = out.cuda()

v = Variable(out)
v = Variable(out, volatile=self.volatile)
return v

def __getitem__(self, index):
@@ -46,3 +50,9 @@ def __getitem__(self, index):

def __len__(self):
return self.numBatches


def shuffle(self):
zipped = list(zip(self.src, self.tgt))
random.shuffle(zipped)
self.src, self.tgt = [x[0] for x in zipped], [x[1] for x in zipped]
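
The two Dataset additions above, the `volatile` flag and `shuffle()`, are easiest to see together in a short usage sketch. This is not part of the diff; the data-dict keys and the batch size are assumptions based on how `train.py` typically consumes the preprocessed `.pt` file.

```python
# Hedged usage sketch, not from the PR. Keys such as data['train']['src'] are
# assumed to match what preprocess.py saves; adjust to the actual file.
import torch
import onmt

data = torch.load('data/demo-train.pt')     # output of preprocess.py (assumed path)

trainData = onmt.Dataset(data['train']['src'], data['train']['tgt'],
                         batchSize=64, cuda=False)
validData = onmt.Dataset(data['valid']['src'], data['valid']['tgt'],
                         batchSize=64, cuda=False,
                         volatile=True)      # validation batches skip autograd bookkeeping

for epoch in range(10):
    trainData.shuffle()                      # new: reshuffle (src, tgt) pairs each epoch
    for i in range(len(trainData)):          # __len__ returns the number of mini-batches
        batch = trainData[i]                 # (src, tgt) pair, as NMTModel.forward expects
        # ... forward/backward pass on this mini-batch ...
```
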
6 changes: 5 additions & 1 deletion OpenNMT/onmt/Dict.py
@@ -2,10 +2,11 @@


class Dict(object):
def __init__(self, data=None):
def __init__(self, data=None, lower=False):
self.idxToLabel = {}
self.labelToIdx = {}
self.frequencies = {}
self.lower = lower

# Special entries will not be pruned.
self.special = []
@@ -37,6 +38,7 @@ def writeFile(self, filename):
file.close()

def lookup(self, key, default=None):
key = key.lower() if self.lower else key
try:
return self.labelToIdx[key]
except KeyError:
@@ -60,6 +62,7 @@ def addSpecials(self, labels):

# Add `label` in the dictionary. Use `idx` as its index if given.
def add(self, label, idx=None):
label = label.lower() if self.lower else label
if idx is not None:
self.idxToLabel[idx] = label
self.labelToIdx[label] = idx
@@ -89,6 +92,7 @@ def prune(self, size):
_, idx = torch.sort(freq, 0, True)

newDict = Dict()
newDict.lower = self.lower

# Add special entries in all cases.
for i in self.special:
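
A quick, hedged illustration of the new `lower` option in Dict (the tokens are made up): with `lower=True`, both `add` and `lookup` case-fold, so differently cased surface forms share a single index and their counts merge.

```python
# Illustrative sketch only, not part of the diff.
import onmt

d = onmt.Dict(lower=True)
d.add('Hello')
d.add('HELLO')                                  # case-folds onto the same entry as 'Hello'

assert d.lookup('hello') == d.lookup('Hello')   # one shared index
assert d.size() == 1                            # a single entry, counted twice
```
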
42 changes: 15 additions & 27 deletions OpenNMT/onmt/Models.py
@@ -10,28 +10,26 @@ def __init__(self, opt, dicts):
self.num_directions = 2 if opt.brnn else 1
assert opt.rnn_size % self.num_directions == 0
self.hidden_size = opt.rnn_size // self.num_directions
inputSize = opt.word_vec_size
input_size = opt.word_vec_size

super(Encoder, self).__init__()
self.word_lut = nn.Embedding(dicts.size(),
opt.word_vec_size,
padding_idx=onmt.Constants.PAD)
self.rnn = nn.LSTM(inputSize, self.hidden_size,
self.rnn = nn.LSTM(input_size, self.hidden_size,
num_layers=opt.layers,
dropout=opt.dropout,
bidirectional=opt.brnn)

# self.rnn.bias_ih_l0.data.div_(2)
# self.rnn.bias_hh_l0.data.copy_(self.rnn.bias_ih_l0.data)

if opt.pre_word_vecs_enc is not None:
pretrained = torch.load(opt.pre_word_vecs_enc)
self.word_lut.weight.copy_(pretrained)

def forward(self, input, hidden=None):
batch_size = input.size(0) # batch first for multi-gpu compatibility
emb = self.word_lut(input).transpose(0, 1)
emb = self.word_lut(input)

if hidden is None:
batch_size = emb.size(1)
h_size = (self.layers * self.num_directions, batch_size, self.hidden_size)
h_0 = Variable(emb.data.new(*h_size).zero_(), requires_grad=False)
c_0 = Variable(emb.data.new(*h_size).zero_(), requires_grad=False)
@@ -46,17 +44,16 @@ def __init__(self, num_layers, input_size, rnn_size, dropout):
super(StackedLSTM, self).__init__()
self.dropout = nn.Dropout(dropout)
self.num_layers = num_layers
self.layers = nn.ModuleList()

for i in range(num_layers):
layer = nn.LSTMCell(input_size, rnn_size)
self.add_module('layer_%d' % i, layer)
self.layers.append(nn.LSTMCell(input_size, rnn_size))
input_size = rnn_size

def forward(self, input, hidden):
h_0, c_0 = hidden
h_1, c_1 = [], []
for i in range(self.num_layers):
layer = getattr(self, 'layer_%d' % i)
for i, layer in enumerate(self.layers):
h_1_i, c_1_i = layer(input, (h_0[i], c_0[i]))
input = h_1_i
if i != self.num_layers:
@@ -87,9 +84,6 @@ def __init__(self, opt, dicts):
self.attn = onmt.modules.GlobalAttention(opt.rnn_size)
self.dropout = nn.Dropout(opt.dropout)

# self.rnn.bias_ih.data.div_(2)
# self.rnn.bias_hh.data.copy_(self.rnn.bias_ih.data)

self.hidden_size = opt.rnn_size

if opt.pre_word_vecs_enc is not None:
@@ -98,39 +92,33 @@ def __init__(self, opt, dicts):


def forward(self, input, hidden, context, init_output):
emb = self.word_lut(input).transpose(0, 1)

batch_size = input.size(0)

h_size = (batch_size, self.hidden_size)
output = Variable(emb.data.new(*h_size).zero_(), requires_grad=False)
emb = self.word_lut(input)

# n.b. you can increase performance if you compute W_ih * x for all
# iterations in parallel, but that's only possible if
# self.input_feed=False
outputs = []
output = init_output
for i, emb_t in enumerate(emb.chunk(emb.size(0), dim=0)):
for emb_t in emb.split(1):
emb_t = emb_t.squeeze(0)
if self.input_feed:
emb_t = torch.cat([emb_t, output], 1)

output, h = self.rnn(emb_t, hidden)
output, hidden = self.rnn(emb_t, hidden)
output, attn = self.attn(output, context.t())
output = self.dropout(output)
outputs += [output]

outputs = torch.stack(outputs)
return outputs.transpose(0, 1), h, attn
return outputs, hidden, attn


class NMTModel(nn.Module):

def __init__(self, encoder, decoder, generator):
def __init__(self, encoder, decoder):
super(NMTModel, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.generator = generator
self.generate = False

def set_generate(self, enabled):
@@ -153,15 +141,15 @@ def _fix_enc_hidden(self, h):

def forward(self, input):
src = input[0]
tgt = input[1][:, :-1] # exclude last target from inputs
tgt = input[1][:-1] # exclude last target from inputs
enc_hidden, context = self.encoder(src)
init_output = self.make_init_decoder_output(context)

enc_hidden = (self._fix_enc_hidden(enc_hidden[0]),
self._fix_enc_hidden(enc_hidden[1]))

out, dec_hidden, _attn = self.decoder(tgt, enc_hidden, context, init_output)
if self.generate:
if hasattr(self, 'generator') and self.generate:
out = self.generator(out)

return out
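
One of the quieter changes in the decoder above is iterating with `emb.split(1)` instead of `emb.chunk(emb.size(0), dim=0)`. A small stand-alone check (shapes are illustrative, not from the PR) of why the two are interchangeable for stepping over time:

```python
# Both calls slice a (seq_len, batch, word_vec_size) tensor into seq_len pieces
# of size 1 along dim 0, one per decoding time step.
import torch

emb = torch.randn(5, 3, 8)                  # (seq_len, batch, word_vec_size), made-up sizes
as_chunks = emb.chunk(emb.size(0), dim=0)   # the old spelling
as_splits = emb.split(1)                    # the new, more direct spelling

assert len(as_chunks) == len(as_splits) == emb.size(0)
for a, b in zip(as_chunks, as_splits):
    assert a.size() == b.size() == (1, 3, 8)
    assert torch.equal(a, b)
```
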
16 changes: 4 additions & 12 deletions OpenNMT/onmt/Optim.py
@@ -1,5 +1,7 @@
import math
import torch.optim as optim
import torch.nn as nn
from torch.nn.utils import clip_grad_norm

class Optim(object):

@@ -29,19 +31,9 @@ def __init__(self, params, method, lr, max_grad_norm, lr_decay=1, start_decay_at

def step(self):
# Compute gradients norm.
grad_norm = 0
for param in self.params:
grad_norm += math.pow(param.grad.data.norm(), 2)

grad_norm = math.sqrt(grad_norm)
shrinkage = self.max_grad_norm / grad_norm

for param in self.params:
if shrinkage < 1:
param.grad.data.mul_(shrinkage)

if self.max_grad_norm:
clip_grad_norm(self.params, self.max_grad_norm)

self.optimizer.step()
return grad_norm

# decay learning rate if val perf does not improve or we hit the start_decay_at limit
def updateLearningRate(self, ppl, epoch):
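
For reference, a hedged sketch of what the `Optim.step` change amounts to: the hand-rolled global-norm computation is replaced by `torch.nn.utils.clip_grad_norm` (spelled `clip_grad_norm_` in later PyTorch releases). The model, loss, and norm cap below are illustrative only.

```python
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_   # `clip_grad_norm` in the PyTorch of this PR

model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()    # produce some gradients to clip
max_grad_norm = 5.0

# What the removed lines did by hand: a global L2 norm over all gradients,
# followed by a uniform shrink if it exceeds the cap.
manual_norm = torch.sqrt(sum(p.grad.norm() ** 2 for p in model.parameters()))

# What the new code calls instead: clips in place and returns that same norm.
clipped_norm = clip_grad_norm_(model.parameters(), max_grad_norm)
assert torch.isclose(manual_norm, clipped_norm)
```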