Skip to content

Commit

Permalink
Architecture settings and readme updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergey Edunov committed Sep 15, 2017
1 parent e734b0f commit a15acdb
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 98 deletions.
30 changes: 30 additions & 0 deletions CONTRIBUTING.md
@@ -0,0 +1,30 @@
# Contributing to FAIR Sequence-to-Sequence Toolkit (PyTorch)
We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `master`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

## Coding Style
We try to follow the PEP style guidelines and encourage you to as well.

## License
By contributing to FAIR Sequence-to-Sequence Toolkit, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
30 changes: 15 additions & 15 deletions README.md
@@ -1,7 +1,7 @@
# Introduction
FAIR Sequence-to-Sequence Toolkit (PyTorch)

This is a PyTorch version of [fairseq](https://github.com/facebookresearch/fairseq), a sequence-to-sequence learning toolkit from Facebook AI Research. The original authors of this reimplementation are (in no particular order) Sergey Edunov, Myle Ott, and Sam Gross. The toolkit implements the fully convolutional model described in [Convolutional Sequence to Sequence Learning](https://arxiv.org/abs/1705.03122). The toolkit features multi-GPU training on a single machine as well as fast beam search generation on both CPU and GPU. We provide pre-trained models for English to French and English to German translation.
This is a PyTorch version of [fairseq](https://github.com/facebookresearch/fairseq), a sequence-to-sequence learning toolkit from Facebook AI Research. The original authors of this reimplementation are (in no particular order) Sergey Edunov, Myle Ott, and Sam Gross. The toolkit implements the fully convolutional model described in [Convolutional Sequence to Sequence Learning](https://arxiv.org/abs/1705.03122) and features multi-GPU training on a single machine as well as fast beam search generation on both CPU and GPU. We provide pre-trained models for English to French and English to German translation.

![Model](fairseq.gif)

Expand All @@ -27,8 +27,9 @@ If you use the code in your paper, then please cite it as:
Currently fairseq-py requires PyTorch from the GitHub repository. There are multiple ways of installing it.
We suggest using [Miniconda3](https://conda.io/miniconda.html) and the following instructions.

* Install Miniconda3 from https://conda.io/miniconda.html create and activate python 3 environment.
* Install Miniconda3 from https://conda.io/miniconda.html; create and activate a Python 3 environment.

* Install PyTorch:
```
conda install gcc numpy cudnn nccl
conda install magma-cuda80 -c soumith
Expand All @@ -44,25 +45,22 @@ pip install -r requirements.txt
NO_DISTRIBUTED=1 python setup.py install
```


Install fairseq by cloning the GitHub repository and by running

* Install fairseq-py by cloning the GitHub repository and running:
```
pip install -r requirements.txt
python setup.py build
python setup.py develop
```

# Quick Start

The following command-line tools are available:
* `python preprocess.py`: Data pre-processing: build vocabularies and binarize training data
* `python train.py`: Train a new model on one or multiple GPUs
* `python generate.py`: Translate pre-processed data with a trained model
* `python generate.py -i`: Translate raw text with a trained model
* `python score.py`: BLEU scoring of generated translations against reference translations


# Quick Start

## Evaluating Pre-trained Models [TO BE ADAPTED]
First, download a pre-trained model along with its vocabularies:
```
Expand Down Expand Up @@ -100,7 +98,7 @@ Check [below](#pre-trained-models) for a full list of pre-trained models availab
## Training a New Model

### Data Pre-processing
The fairseq source distribution contains an example pre-processing script for
The fairseq-py source distribution contains an example pre-processing script for
the IWSLT 2014 German-English corpus.
Pre-process and binarize the data as follows:
```
Expand All @@ -118,11 +116,10 @@ This will write binarized data that can be used for model training to `data-bin/
Use `python train.py` to train a new model.
Here a few example settings that work well for the IWSLT 2014 dataset:
```
$ mkdir -p trainings/fconv
$ mkdir -p checkpoints/fconv
$ CUDA_VISIBLE_DEVICES=0 python train.py data-bin/iwslt14.tokenized.de-en \
--lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
--encoder-layers "[(256, 3)] * 4" --decoder-layers "[(256, 3)] * 3" \
--encoder-embed-dim 256 --decoder-embed-dim 256 --save-dir trainings/fconv
--arch fconv_iwslt_de_en --save-dir checkpoints/fconv
```

By default, `python train.py` will use all available GPUs on your machine.
Expand All @@ -135,7 +132,7 @@ You may need to use a smaller value depending on the available GPU memory on you
Once your model is trained, you can generate translations using `python generate.py` **(for binarized data)** or `python generate.py -i` **(for raw text)**:
```
$ python generate.py data-bin/iwslt14.tokenized.de-en \
--path trainings/fconv/checkpoint_best.pt \
--path checkpoints/fconv/checkpoint_best.pt \
--batch-size 128 --beam 5
| [de] dictionary: 35475 types
| [en] dictionary: 24739 types
Expand Down Expand Up @@ -172,9 +169,12 @@ $ python generate.py data-bin/wmt14.en-fr.newstest2014 \
...
| Translated 3003 sentences (95451 tokens) in 136.3s (700.49 tokens/s)
| Timings: setup 0.1s (0.1%), encoder 1.9s (1.4%), decoder 108.9s (79.9%), search_results 0.0s (0.0%), search_prune 12.5s (9.2%)
TODO: update scores (should be same as score.py)
| BLEU4 = 43.43, 68.2/49.2/37.4/28.8 (BP=0.996, ratio=1.004, sys_len=92087, ref_len=92448)
# Word-level BLEU scoring:
# Scoring with score.py:
$ grep ^H /tmp/gen.out | cut -f3- | sed 's/@@ //g' > /tmp/gen.out.sys
$ grep ^T /tmp/gen.out | cut -f2- | sed 's/@@ //g' > /tmp/gen.out.ref
$ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
TODO: update scores
BLEU4 = 40.55, 67.6/46.5/34.0/25.3 (BP=1.000, ratio=0.998, sys_len=81369, ref_len=81194)
Expand All @@ -186,6 +186,6 @@ BLEU4 = 40.55, 67.6/46.5/34.0/25.3 (BP=1.000, ratio=0.998, sys_len=81369, ref_le
* Google group: https://groups.google.com/forum/#!forum/fairseq-users

# License
fairseq is BSD-licensed.
fairseq-py is BSD-licensed.
The license applies to the pre-trained models as well.
We also provide an additional patent grant.
15 changes: 10 additions & 5 deletions fairseq/models/__init__.py
Expand Up @@ -6,9 +6,14 @@
# can be found in the PATENTS file in the same directory.
#

from .fconv import *
from . import fconv

__all__ = [
'fconv', 'fconv_iwslt_de_en', 'fconv_wmt_en_ro', 'fconv_wmt_en_de',
'fconv_wmt_en_fr',
]

__all__ = ['fconv']

arch_model_map = {}
for model in __all__:
archs = locals()[model].get_archs()
for arch in archs:
assert arch not in arch_model_map, 'Duplicate model architecture detected: {}'.format(arch)
arch_model_map[arch] = model
124 changes: 79 additions & 45 deletions fairseq/models/fconv.py
Expand Up @@ -430,56 +430,90 @@ def backward(ctx, grad):
return grad * ctx.scale, None


def fconv_iwslt_de_en(dataset, dropout, **kwargs):
encoder_convs = [(256, 3)] * 4
decoder_convs = [(256, 3)] * 3
return fconv(dataset, dropout, 256, encoder_convs, 256, decoder_convs, **kwargs)


def fconv_wmt_en_ro(dataset, dropout, **kwargs):
convs = [(512, 3)] * 20
return fconv(dataset, dropout, 512, convs, 512, convs, **kwargs)


def fconv_wmt_en_de(dataset, dropout, **kwargs):
convs = [(512, 3)] * 9 # first 10 layers have 512 units
convs += [(1024, 3)] * 4 # next 3 layers have 768 units
convs += [(2048, 1)] * 2 # final 2 layers are 1x1
return fconv(dataset, dropout, 768, convs, 768, convs,
decoder_out_embed_dim=512,
**kwargs)


def fconv_wmt_en_fr(dataset, dropout, **kwargs):
convs = [(512, 3)] * 6 # first 5 layers have 512 units
convs += [(768, 3)] * 4 # next 4 layers have 768 units
convs += [(1024, 3)] * 3 # next 4 layers have 1024 units
convs += [(2048, 1)] * 1 # next 1 layer is 1x1
convs += [(4096, 1)] * 1 # final 1 layer is 1x1
return fconv(dataset, dropout, 768, convs, 768, convs,
decoder_out_embed_dim=512,
**kwargs)


def fconv(dataset, dropout, encoder_embed_dim, encoder_convolutions,
decoder_embed_dim, decoder_convolutions, attention=True,
decoder_out_embed_dim=256, max_positions=1024):
def get_archs():
return [
'fconv', 'fconv_iwslt_de_en', 'fconv_wmt_en_ro', 'fconv_wmt_en_de', 'fconv_wmt_en_fr',
]


def _check_arch(args):
"""Check that the specified architecture is valid and not ambiguous."""
if args.arch not in get_archs():
raise ValueError('Unknown fconv model architecture: {}'.format(args.arch))
if args.arch != 'fconv':
# check that architecture is not ambiguous
for a in ['encoder_embed_dim', 'encoder_layers', 'decoder_embed_dim', 'decoder_layers',
'decoder_out_embed_dim']:
if hasattr(args, a):
raise ValueError('--{} cannot be combined with --arch={}'.format(a, args.arch))


def parse_arch(args):
_check_arch(args)

if args.arch == 'fconv_iwslt_de_en':
args.encoder_embed_dim = 256
args.encoder_layers = '[(256, 3)] * 4'
args.decoder_embed_dim = 256
args.decoder_layers = '[(256, 3)] * 3'
args.decoder_out_embed_dim = 256
elif args.arch == 'fconv_wmt_en_ro':
args.encoder_embed_dim = 512
args.encoder_layers = '[(512, 3)] * 20'
args.decoder_embed_dim = 512
args.decoder_layers = '[(512, 3)] * 20'
args.decoder_out_embed_dim = 512
elif args.arch == 'fconv_wmt_en_de':
convs = '[(512, 3)] * 9' # first 9 layers have 512 units
convs += ' + [(1024, 3)] * 4' # next 4 layers have 1024 units
convs += ' + [(2048, 1)] * 2' # final 2 layers use 1x1 convolutions
args.encoder_embed_dim = 768
args.encoder_layers = convs
args.decoder_embed_dim = 768
args.decoder_layers = convs
args.decoder_out_embed_dim = 512
elif args.arch == 'fconv_wmt_en_fr':
convs = '[(512, 3)] * 6' # first 6 layers have 512 units
convs += ' + [(768, 3)] * 4' # next 4 layers have 768 units
convs += ' + [(1024, 3)] * 3' # next 3 layers have 1024 units
convs += ' + [(2048, 1)] * 1' # next 1 layer uses 1x1 convolutions
convs += ' + [(4096, 1)] * 1' # final 1 layer uses 1x1 convolutions
args.encoder_embed_dim = 768
args.encoder_layers = convs
args.decoder_embed_dim = 768
args.decoder_layers = convs
args.decoder_out_embed_dim = 512
else:
assert args.arch == 'fconv'

# default architecture
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.encoder_layers = getattr(args, 'encoder_layers', '[(512, 3)] * 20')
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
args.decoder_layers = getattr(args, 'decoder_layers', '[(512, 3)] * 20')
args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
args.decoder_attention = getattr(args, 'decoder_attention', 'True')
return args


def build_model(args, dataset):
padding_idx = dataset.dst_dict.pad()

encoder = Encoder(
len(dataset.src_dict),
embed_dim=encoder_embed_dim,
convolutions=encoder_convolutions,
dropout=dropout,
embed_dim=args.encoder_embed_dim,
convolutions=eval(args.encoder_layers),
dropout=args.dropout,
padding_idx=padding_idx,
max_positions=max_positions)
max_positions=args.max_positions,
)
decoder = Decoder(
len(dataset.dst_dict),
embed_dim=decoder_embed_dim,
convolutions=decoder_convolutions,
out_embed_dim=decoder_out_embed_dim,
attention=attention,
dropout=dropout,
embed_dim=args.decoder_embed_dim,
convolutions=eval(args.decoder_layers),
out_embed_dim=args.decoder_out_embed_dim,
attention=eval(args.decoder_attention),
dropout=args.dropout,
padding_idx=padding_idx,
max_positions=max_positions)
max_positions=args.max_positions,
)
return FConvModel(encoder, decoder, padding_idx)
35 changes: 24 additions & 11 deletions fairseq/options.py
Expand Up @@ -109,22 +109,35 @@ def add_generation_args(parser):


def add_model_args(parser):
group = parser.add_argument_group('Model configuration')
group.add_argument('--arch', '-a', default='fconv', metavar='ARCH',
choices=models.__all__,
help='model architecture ({})'.format(', '.join(models.__all__)))
group.add_argument('--encoder-embed-dim', default=512, type=int, metavar='N',
group = parser.add_argument_group(
'Model configuration',
# Only include attributes which are explicitly given as command-line
# arguments or which have model-independent default values.
argument_default=argparse.SUPPRESS,
)

# The model architecture can be specified in several ways.
# In increasing order of priority:
# 1) model defaults (lowest priority)
# 2) --arch argument
# 3) --encoder/decoder-* arguments (highest priority)
# Note: --arch cannot be combined with --encoder/decoder-* arguments.
group.add_argument('--arch', '-a', default='fconv', metavar='ARCH', choices=models.arch_model_map.keys(),
help='model architecture ({})'.format(', '.join(models.arch_model_map.keys())))
group.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
group.add_argument('--encoder-layers', default='[(512, 3)] * 20', type=str, metavar='EXPR',
group.add_argument('--encoder-layers', type=str, metavar='EXPR',
help='encoder layers [(dim, kernel_size), ...]')
group.add_argument('--decoder-embed-dim', default=512, type=int, metavar='N',
group.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
group.add_argument('--decoder-layers', default='[(512, 3)] * 20', type=str, metavar='EXPR',
group.add_argument('--decoder-layers', type=str, metavar='EXPR',
help='decoder layers [(dim, kernel_size), ...]')
group.add_argument('--decoder-attention', default='True', type=str, metavar='EXPR',
help='decoder attention [True, ...]')
group.add_argument('--decoder-out-embed-dim', default=256, type=int, metavar='N',
group.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
help='decoder output embedding dimension')
group.add_argument('--decoder-attention', type=str, metavar='EXPR',
help='decoder attention [True, ...]')

# These arguments have default values independent of the model:
group.add_argument('--dropout', default=0.1, type=float, metavar='D',
help='dropout probability')
group.add_argument('--label-smoothing', default=0, type=float, metavar='D',
Expand Down

0 comments on commit a15acdb

Please sign in to comment.