Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
bcbbe4f
commit d186d1e
Showing
69 changed files
with
19,875 additions
and
16,193 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,105 @@ | ||
*/__pycache__/* | ||
*.ipynb_checkpoints/* | ||
*.idea/* | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
.idea/ | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,155 +1,36 @@ | ||
# PyTorch solution of NER task with Google AI's BERT model | ||
## 0. Introduction | ||
## | ||
This repository contains solution of NER task based on BERT withot fine-tuning BERT model. | ||
|
||
This repository contains solution of NER task based on PyTorch [reimplementation](https://github.com/huggingface/pytorch-pretrained-BERT) of [Google's TensorFlow repository for the BERT model](https://github.com/google-research/bert) that was released together with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. | ||
### Papers | ||
There are two solutions based on this architecture. | ||
1. [BSNLP 2019 ACL workshop](http://bsnlp.cs.helsinki.fi/shared_task.html): [solution](https://github.com/king-menin/slavic-ner) and [paper](https://arxiv.org/abs/1906.09978) on multilingual shared task. | ||
2. The second place [solution](https://github.com/king-menin/AGRR-2019) of [Dialogue AGRR-2019](https://github.com/dialogue-evaluation/AGRR-2019) task. | ||
|
||
This implementation can load any pre-trained TensorFlow checkpoint for BERT (in particular [Google's pre-trained models](https://github.com/google-research/bert)) and a conversion script is provided (see below). | ||
|
||
## 1. Loading a TensorFlow checkpoint (e.g. [Google's pre-trained models](https://github.com/google-research/bert#pre-trained-models)) | ||
|
||
You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) in a PyTorch save file by using the [`convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py) script. | ||
|
||
This script takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()`. | ||
|
||
You only need to run this conversion script **once** to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with `bert_model.ckpt`) but be sure to keep the configuration file (`bert_config.json`) and the vocabulary file (`vocab.txt`) as these are needed for the PyTorch model too. | ||
|
||
To run this specific conversion script you will need to have TensorFlow and PyTorch installed (`pip install tensorflow`). The rest of the repository only requires PyTorch. | ||
|
||
Here is an example of the conversion process for a pre-trained `BERT-Base Uncased` model: | ||
|
||
```shell | ||
export BERT_BASE_DIR=/path/to/bert/multilingual_L-12_H-768_A-12 | ||
|
||
python3 convert_tf_checkpoint_to_pytorch.py \ | ||
--tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \ | ||
--bert_config_file $BERT_BASE_DIR/bert_config.json \ | ||
--pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin | ||
``` | ||
|
||
You can download Google's pre-trained models for the conversion [here](https://github.com/google-research/bert#pre-trained-models). | ||
|
||
There is used the [BERT-Base, Multilingual](https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip) and [BERT-Cased, Multilingual](https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip) (recommended) in this solution. | ||
|
||
## 2. Results | ||
We didn't search best parametres and obtained the following results for no more than <b>10 epochs</b>. | ||
We didn't search best parametres and obtained the following results. | ||
|
||
### Only NER models | ||
#### Model: `BertBiLSTMAttnCRF`. | ||
|
||
| Dataset | Lang | IOB precision | Span precision | Total spans in test set | Notebook | ||
| Model | Data set | Dev F1 tok | Dev F1 span | Test F1 tok | Test F1 span | ||
|-|-|-|-|-|-| | ||
| [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | ru | <b>0.937</b> | <b>0.883</b> | 4 | [factrueval.ipynb](examples/factrueval.ipynb) | ||
| [Atis](https://github.com/Microsoft/CNTK/tree/master/Examples/LanguageUnderstanding/ATIS/Data) | en | 0.852 | 0.787 | 65 | [conll-2003.ipynb](examples/conll-2003.ipynb) | ||
| [Conll-2003](https://github.com/kyzhouhzau/BERT-NER/tree/master/NERdata) | en | <b>0.945</b> | 0.858 | 5 | [atis.ipynb](examples/atis.ipynb) | ||
|
||
* Factrueval (f1): 0.9163±0.006, best **0.926**. | ||
* Atis (f1): 0.882±0.02, best **0.896** | ||
* Conll-2003 (f1, dev): 0.949±0.002, best **0.951**; 0.892 (f1, test). | ||
|
||
#### Model: `BertBiLSTMAttnNMT`. | ||
|
||
| Dataset | Lang | IOB precision | Span precision | Total spans in test set | Notebook | ||
|**OURS**|||||| | ||
| M-BERTCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8598 | 0.7676 | ||
| M-BERTNCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8603 | 0.7783 | ||
| M-BERTBiLSTMCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8780 | 0.8108 | ||
| M-BERTBiLSTMCRF-BIO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8263 | 0.8051 | ||
| M-BERTBiLSTMNCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8594 | 0.7842 | ||
| M-BERTAttnCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8630 | 0.7879 | ||
| M-BERTBiLSTMAttnCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8851 | 0.8244 | ||
| M-BERTBiLSTMAttnNCRF-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8609 | 0.7869 | ||
| M-BERTBiLSTMAttnNCRF-fit_BERT-IO | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | 0.8739 | 0.8201 | ||
|-|-|-|-|-|-| | ||
| [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | ru | 0.925 | 0.827 | 4 | [factrueval-nmt.ipynb](examples/factrueval-nmt.ipynb) | ||
| [Atis](https://github.com/Microsoft/CNTK/tree/master/Examples/LanguageUnderstanding/ATIS/Data) | en | <b>0.919</b> | <b>0.829</b> | 65 | [atis-nmt.ipynb](examples/atis-nmt.ipynb) | ||
| [Conll-2003](https://github.com/kyzhouhzau/BERT-NER/tree/master/NERdata) | en | 0.936 | <b>0.900</b> | 5 | [conll-2003-nmt.ipynb](examples/conll-2003-nmt.ipynb) | ||
|
||
### Joint Models | ||
#### Model: `BertBiLSTMAttnCRFJoint` | ||
|
||
| Dataset | Lang | IOB precision | Span precision | Clf precision | Total spans in test set | Total classes | Notebook | ||
|-|-|-|-|-|-|-|-| | ||
| [Atis](https://github.com/Microsoft/CNTK/tree/master/Examples/LanguageUnderstanding/ATIS/Data) | en | 0.877 | 0.824 | 0.894 | 65 | 17 | [atis-joint.ipynb](examples/atis-joint.ipynb) | ||
|
||
#### Model: `BertBiLSTMAttnNMTJoint` | ||
|
||
| Dataset | Lang | IOB precision | Span precision | Clf precision | Total spans in test set | Total classes | Notebook | ||
|-|-|-|-|-|-|-|-| | ||
| [Atis](https://github.com/Microsoft/CNTK/tree/master/Examples/LanguageUnderstanding/ATIS/Data) | en | 0.913 | 0.820 | 0.888 | 65 | 17 | [atis-joint-nmt.ipynb](examples/atis-joint-nmt.ipynb) | ||
|
||
### Comprasion with ELMo model | ||
We tested `BertBiLSTMCRF`, `BertBiLSTMAttnCRF` and `BertBiLSTMAttnNMT` on russian dataset [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) with freezed `ElmoEmbedder`: | ||
|
||
#### Model `BertBiLSTMCRF`: | ||
|
||
| Dataset | Lang | IOB precision | Span precision | Total spans in test set | Notebook | ||
|-|-|-|-|-|-| | ||
| [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | ru | 0.903 | 0.851 | 4 | [samples.ipynb](examples_elmo/samples.ipynb) | ||
|
||
#### Model `BertBiLSTMAttnCRF`: | ||
|
||
| Dataset | Lang | IOB precision | Span precision | Total spans in test set | Notebook | ||
|-|-|-|-|-|-| | ||
| [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | ru | 0.899 | 0.819 | 4 | [factrueval.ipynb](examples_elmo/factrueval.ipynb) | ||
|
||
#### Model `BertBiLSTMAttnNMT`: | ||
|
||
| Dataset | Lang | IOB precision | Span precision | Total spans in test set | Notebook | ||
|-|-|-|-|-|-| | ||
| [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | ru | 0.902 | 0.752 | 4 | [factrueval-nmt.ipynb](examples_elmo/factrueval.ipynb) | ||
|
||
|
||
## 3. Installation, requirements, test | ||
|
||
This code was tested on Python 3.5+. The requirements are: | ||
|
||
- PyTorch (>= 0.4.1) | ||
- tqdm | ||
- tensorflow (for convertion) | ||
|
||
To install the dependencies: | ||
|
||
````bash | ||
pip install -r ./requirements.txt | ||
```` | ||
|
||
## PyTorch neural network models | ||
|
||
All models are organized as `Encoder`-`Decoder`. `Encoder` is a freezed and <i>weighted</i> (as proposed in [elmo](https://allennlp.org/elmo)) bert output from 12 layers. There are three models that is obtained by using different `Decoder`. | ||
|
||
`Encoder`: BertBiLSTM | ||
|
||
1. `BertBiLSTMCRF`: `Encoder` + `Decoder` (BiLSTM + CRF) | ||
2. `BertBiLSTMAttnCRF`: `Encoder` + `Decoder` (BiLSTM + MultiHead Attention + CRF) | ||
3. `BertBiLSTMAttnNMT`: `Encoder` + `Decoder` (LSTM + Bahdanau Attention - NMT Decode) | ||
4. `BertBiLSTMAttnCRFJoint`: `Encoder` + `Decoder` (BiLSTM + MultiHead Attention + CRF) + (PoolingLinearClassifier - for classification) - joint model with classification. | ||
5. `BertBiLSTMAttnNMTJoint`: `Encoder` + `Decoder` (LSTM + Bahdanau Attention - NMT Decode) + (LinearClassifier - for classification) - joint model with classification. | ||
|
||
|
||
## Usage | ||
|
||
### 1. Loading data: | ||
|
||
```from modules.bert_data import BertNerData as NerData``` | ||
|
||
```data = NerData.create(train_path, valid_path, vocab_file)``` | ||
|
||
### 2. Create model: | ||
|
||
```from modules.bert_models import BertBiLSTMCRF``` | ||
|
||
```model = BertBiLSTMCRF.create(len(data.label2idx), bert_config_file, init_checkpoint_pt, enc_hidden_dim=256)``` | ||
|
||
### 3. Create learner: | ||
|
||
```from modules.train import NerLearner``` | ||
|
||
```learner = NerLearner(model, data, best_model_path="/datadrive/models/factrueval/exp_final.cpt", lr=0.01, clip=1.0, sup_labels=data.id2label[5:], t_total=num_epochs * len(data.train_dl))``` | ||
|
||
### 4. Learn your NER model: | ||
|
||
```learner.fit(2, target_metric='prec')``` | ||
|
||
### 5. Predict on new data: | ||
|
||
```from modules.data.bert_data import get_bert_data_loader_for_predict``` | ||
|
||
```dl = get_bert_data_loader_for_predict(data_path + "valid.csv", learner)``` | ||
|
||
```learner.load_model(best_model_path)``` | ||
|
||
```preds = learner.predict(dl)``` | ||
|
||
|
||
* For more detailed instructions of using BERT model see [samples.ipynb](examples/samples.ipynb). | ||
* For more detailed instructions of using ELMo model see [samples.ipynb](examples_elmo/samples.ipynb). | ||
| BERTBiLSTMCRF-IO | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.9624 | 0.9273 | - | - | ||
| BERTBiLSTMCRF-BIO | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.9530 | 0.9236 | - | - | ||
| B-BERTBiLSTMCRF-IO | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.9635 | 0.9277 | - | - | ||
| B-BERTBiLSTMCRF-BIO | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.9536 | 0.9156 | - | - | ||
| B-BERTBiLSTMAttnCRF-IO | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.9571 | 0.9114 | - | - | ||
| B-BERTBiLSTMAttnNCRF-IO | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.9631 | 0.9197 | - | - | ||
|**Current SOTA**|||||| | ||
| DeepPavlov-RuBERT-NER | [FactRuEval](https://github.com/dialogue-evaluation/factRuEval-2016) | - | - | - | **0.8266** | ||
| CSE | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | - | - | **0.931** | - | ||
| BERT-LARGE | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.966 | - | 0.928 | - | ||
| BERT-BASE | [CoNLL-2003](https://github.com/synalp/NER/tree/master/corpus/CoNLL-2003) | 0.964 | - | 0.924 | - |
Oops, something went wrong.