Skip to content

Commit

Permalink
XLM-R code and model release (#900)
Browse files Browse the repository at this point in the history
Summary:
TODO:
1) Need to update bibtex entry
2) Need to upload models, spm_vocab and dict.txt to public s3 location.

For Future:

1) I will probably add instructions to finetune on XNLI and NER, POS etc. but currently no timeline for that.
Pull Request resolved: fairinternal/fairseq-py#900

Reviewed By: myleott

Differential Revision: D18333076

Pulled By: myleott

fbshipit-source-id: 3f3d3716fcc41c78d2dd4525f60b519abbd0459c
  • Loading branch information
ngoyal2707 authored and facebook-github-bot committed Nov 5, 2019
1 parent 68dd3e1 commit e23e5ea
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ modeling and other text generation tasks.

### What's New:

- November 2019: [XLM-R models and code released](examples/xlmr/README.md)
- September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
- August 2019: [WMT'19 models released](examples/wmt19/README.md)
- July 2019: fairseq relicensed under MIT license
Expand Down
1 change: 1 addition & 0 deletions examples/roberta/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ RoBERTa iterates on BERT's pretraining procedure, including training the model l

### What's New:

- November 2019: Multilingual encoder (XLM-RoBERTa) is available [XLM-R](https://github.com/pytorch/fairseq/examples/xlmr).
- September 2019: TensorFlow and TPU support via the [transformers library](https://github.com/huggingface/transformers).
- August 2019: RoBERTa is now supported in the [pytorch-transformers library](https://github.com/huggingface/pytorch-transformers).
- August 2019: Added [tutorial for finetuning on WinoGrande](https://github.com/pytorch/fairseq/tree/master/examples/roberta/wsc#roberta-training-on-winogrande-dataset).
Expand Down
77 changes: 77 additions & 0 deletions examples/xlmr/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Unsupervised Cross-lingual Representation Learning at Scale (XLM-RoBERTa)

## Introduction

XLM-R (XLM-RoBERTa) is scaled cross lingual sentence encoder. It is trained on `2.5T` of data across `100` languages data filtered from Common Crawl. XLM-R achieves state-of-the-arts results on multiple cross lingual benchmarks.

## Pre-trained models

Model | Description | # params | Download
---|---|---|---
`xlmr.base.v0` | XLM-R using the BERT-base architecture | 250M | [xlm.base.v0.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.v0.tar.gz)
`xlmr.large.v0` | XLM-R using the BERT-large architecture | 560M | [xlm.large.v0.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.v0.tar.gz)

(Note: The above models are still under training, we will update the weights, once fully trained, the results are based on the above checkpoints.)

## Results

**[XNLI (Conneau et al., 2018)](https://arxiv.org/abs/1809.05053)**

Model | en | fr | es | de | el | bg | ru | tr | ar | vi | th | zh | hi | sw | ur
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
`roberta.large.mnli` _(TRANSLATE-TEST)_ | 91.3 | 82.9 | 84.3 | 81.24 | 81.74 | 83.13 | 78.28 | 76.79 | 76.64 | 74.17 | 74.05 | 77.5 | 70.9 | 66.65 | 66.81
`xlmr.large.v0` _(TRANSLATE-TRAIN-ALL)_ | 88.7 | 85.2 | 85.6 | 84.6 | 83.6 | 85.5 | 82.4 | 81.6 | 80.9 | 83.4 | 80.9 | 83.3 | 79.8 | 75.9 | 74.3

## Example usage

##### Load XLM-R from torch.hub (PyTorch >= 1.1):
```python
import torch
xlmr = torch.hub.load('pytorch/fairseq', 'xlmr.large.v0')
xlmr.eval() # disable dropout (or leave in train mode to finetune)
```

##### Load XLM-R (for PyTorch 1.0 or custom models):
```python
# Download xlmr.large model
wget https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.v0.tar.gz
tar -xzvf xlmr.large.v0.tar.gz

# Load the model in fairseq
from fairseq.models.roberta import XLMRModel
xlmr = XLMRModel.from_pretrained('/path/to/xlmr.large.v0', checkpoint_file='model.pt')
xlmr.eval() # disable dropout (or leave in train mode to finetune)
```

##### Apply Byte-Pair Encoding (BPE) to input text:
```python
tokens = xlmr.encode('Hello world!')
assert tokens.tolist() == [ 0, 35378, 8999, 38, 2]
xlmr.decode(tokens) # 'Hello world!'
```

##### Extract features from XLM-R:
```python
# Extract the last layer's features
last_layer_features = xlmr.extract_features(tokens)
assert last_layer_features.size() == torch.Size([1, 5, 1024])

# Extract all layer's features (layer 0 is the embedding layer)
all_layers = xlmr.extract_features(tokens, return_all_hiddens=True)
assert len(all_layers) == 25
assert torch.all(all_layers[-1] == last_layer_features)
```

## Citation

```bibtex
@article{,
title = {Unsupervised Cross-lingual Representation Learning at Scale},
author = {Alexis Conneau and Kartikay Khandelwal and Naman Goyal
and Vishrav Chaudhary and Guillaume Wenzek and Francisco Guzm\'an
and Edouard Grave and Myle Ott and Luke Zettlemoyer and Veselin Stoyanov
},
journal={},
year = {2019},
}
```
24 changes: 24 additions & 0 deletions fairseq/models/roberta/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,30 @@ def upgrade_state_dict_named(self, state_dict, name):
state_dict[prefix + 'classification_heads.' + k] = v


@register_model('xlmr')
class XLMRModel(RobertaModel):
@classmethod
def hub_models(cls):
return {
'xlmr.base.v0': 'http://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.v0.tar.gz',
'xlmr.large.v0': 'http://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.v0.tar.gz',
}

@classmethod
def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', bpe='sentencepiece', **kwargs):
from fairseq import hub_utils
x = hub_utils.from_pretrained(
model_name_or_path,
checkpoint_file,
data_name_or_path,
archive_map=cls.hub_models(),
bpe=bpe,
load_checkpoint_heads=True,
**kwargs,
)
return RobertaHubInterface(x['args'], x['task'], x['models'][0])


class RobertaLMHead(nn.Module):
"""Head for masked language modeling."""

Expand Down

0 comments on commit e23e5ea

Please sign in to comment.