# Data augmentation with transformer models for named entity recognition

> Article: https://www.depends-on-the-definition.com/data-augmentation-with-transformers/ by [Tobias Sterbak](https://twitter.com/tobias_sterbak)

> Colab Creator: [Manu Romero](https://twitter.com/mrm8488)

![image](https://d33wubrfki0l68.cloudfront.net/cd99f03175460ab0a43c84b1ad7802adb4409295/3c45b/images/data-augmentation-with-transformers_files/data_augmentation.png)

### Download the dataset

In [None]:
# Download the token-level NER corpus (columns: Sentence #, Word, POS, Tag).
!wget https://raw.githubusercontent.com/mrm8488/NER-English/master/ner_dataset.csv

--2020-08-24 10:06:50--  https://raw.githubusercontent.com/mrm8488/NER-English/master/ner_dataset.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 15208151 (15M) [text/plain]
Saving to: ‘ner_dataset.csv’


2020-08-24 10:06:50 (50.0 MB/s) - ‘ner_dataset.csv’ saved [15208151/15208151]



### Install and import required libraries

In [None]:
# NOTE(review): the version is unpinned, but TransformerAugmenter below passes
# `topk=` to `pipeline(...)`, a keyword from older transformers releases (newer
# ones use `top_k`) — consider pinning the version this notebook was run with.
!pip install -q transformers

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from tqdm.notebook import tqdm

import random

SEED = 2020
# Seed every RNG used below: torch, and Python's `random`, which picks the
# tokens to mask (and the replacement candidates) in the augmentation step.
torch.manual_seed(SEED)
random.seed(SEED)

# Only query the device name when CUDA is present; the call fails on
# CPU-only runtimes.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(torch.cuda.current_device()))
print(torch.cuda.is_available())
print(torch.__version__)


Tesla P100-PCIE-16GB
True
1.6.0+cu101


### Load the dataset and inspect it

In [None]:
import pandas as pd
import numpy as np

# The raw CSV presumably leaves "Sentence #" blank after each sentence's first
# token; forward-fill so every token row carries a sentence label.
# `DataFrame.ffill()` replaces the deprecated `fillna(method="ffill")`.
data = pd.read_csv("ner_dataset.csv", encoding="latin1")
data = data.ffill()

In [None]:
# First rows: one row per token with its sentence id, word, POS tag and BIO entity tag.
data.head(20)

Unnamed: 0,Sentence #,Word,POS,Tag
0,Sentence: 1,Thousands,NNS,O
1,Sentence: 1,of,IN,O
2,Sentence: 1,demonstrators,NNS,O
3,Sentence: 1,have,VBP,O
4,Sentence: 1,marched,VBN,O
5,Sentence: 1,through,IN,O
6,Sentence: 1,London,NNP,B-geo
7,Sentence: 1,to,TO,O
8,Sentence: 1,protest,VB,O
9,Sentence: 1,the,DT,O


In [None]:
# Last rows of the token table.
data.tail(20)

Unnamed: 0,Sentence #,Word,POS,Tag
1048555,Sentence: 47957,.,.,O
1048556,Sentence: 47958,They,PRP,O
1048557,Sentence: 47958,say,VBP,O
1048558,Sentence: 47958,not,RB,O
1048559,Sentence: 47958,all,DT,O
1048560,Sentence: 47958,of,IN,O
1048561,Sentence: 47958,the,DT,O
1048562,Sentence: 47958,rockets,NNS,O
1048563,Sentence: 47958,exploded,VBD,O
1048564,Sentence: 47958,upon,IN,O


In [None]:
# Tag distribution: "O" dominates and several entity types (art, eve, nat) are rare.
data['Tag'].value_counts()

O        887908
B-geo     37644
B-tim     20333
B-org     20143
I-per     17251
B-per     16990
I-org     16784
B-gpe     15870
I-geo      7414
I-tim      6528
B-art       402
B-eve       308
I-art       297
I-eve       253
B-nat       201
I-gpe       198
I-nat        51
Name: Tag, dtype: int64

#### Create a sentence getter

In [None]:
class SentenceGetter(object):
    """Group a token-per-row NER DataFrame into sentences.

    ``data`` must have the columns ``"Sentence #"``, ``"Word"``, ``"POS"`` and
    ``"Tag"``. Each sentence becomes a list of ``(word, pos, tag)`` tuples.

    Attributes:
        grouped: Series mapping each "Sentence #" label to its tuple list.
        sentences: list of all sentences in the order of ``grouped``. Note that
            ``groupby`` sorts the labels as strings by default, so
            "Sentence: 10" comes before "Sentence: 2".
        n_sent: number of the next sentence returned by ``get_next``.
    """

    def __init__(self, data):
        self.n_sent = 1
        self.data = data
        self.empty = False
        agg_func = lambda s: list(zip(s["Word"].values.tolist(),
                                      s["POS"].values.tolist(),
                                      s["Tag"].values.tolist()))
        # Select the value columns explicitly so the grouping column is not
        # passed to `apply` (deprecated behaviour in recent pandas).
        self.grouped = (
            self.data.groupby("Sentence #")[["Word", "POS", "Tag"]].apply(agg_func)
        )
        self.sentences = list(self.grouped)

    def get_next(self):
        """Return sentence "Sentence: <n_sent>" and advance, or ``None`` when
        no sentence with that label exists."""
        try:
            s = self.grouped["Sentence: {}".format(self.n_sent)]
        except KeyError:
            # Only a missing label means "no more sentences"; any other error
            # is a real bug and must propagate.
            return None
        self.n_sent += 1
        return s


In [None]:
# Group the token table into per-sentence lists.
getter = SentenceGetter(data)

In [None]:
# List of sentences, each a list of (word, POS, tag) tuples.
sentences = getter.sentences

In [None]:
# Inspect one grouped sentence.
sentences[0]

[('Thousands', 'NNS', 'O'),
 ('of', 'IN', 'O'),
 ('demonstrators', 'NNS', 'O'),
 ('have', 'VBP', 'O'),
 ('marched', 'VBN', 'O'),
 ('through', 'IN', 'O'),
 ('London', 'NNP', 'B-geo'),
 ('to', 'TO', 'O'),
 ('protest', 'VB', 'O'),
 ('the', 'DT', 'O'),
 ('war', 'NN', 'O'),
 ('in', 'IN', 'O'),
 ('Iraq', 'NNP', 'B-geo'),
 ('and', 'CC', 'O'),
 ('demand', 'VB', 'O'),
 ('the', 'DT', 'O'),
 ('withdrawal', 'NN', 'O'),
 ('of', 'IN', 'O'),
 ('British', 'JJ', 'B-gpe'),
 ('troops', 'NNS', 'O'),
 ('from', 'IN', 'O'),
 ('that', 'DT', 'O'),
 ('country', 'NN', 'O'),
 ('.', '.', 'O')]

In [None]:
# "[PAD]" takes index 0 in both vocabularies and "[UNK]" index 1 in the word
# vocabulary. The remaining entries are sorted so the mapping is identical on
# every run: iterating a plain set of strings depends on the interpreter's
# hash seed. `key=str` keeps sorting safe if a non-string value slips in.
tags = ["[PAD]"]
tags.extend(sorted(set(data["Tag"].values), key=str))
tag2idx = {t: i for i, t in enumerate(tags)}

words = ["[PAD]", "[UNK]"]
words.extend(sorted(set(data["Word"].values), key=str))
word2idx = {t: i for i, t in enumerate(words)}

In [None]:
# Positional split: first 15,000 sentences for test, the next 5,000 for
# validation, and everything after that for training.
test_sentences = sentences[:15000]
val_sentences = sentences[15000:20000]
train_sentences = sentences[20000:]

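### Build a data augmenter with a transformer model
This is based on the Hugging Face `transformers` `fill-mask` pipeline: a randomly chosen token is masked, the masked language model proposes replacements, and the new word keeps the original token's POS and NER tags.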

In [None]:
import random
from transformers import pipeline

In [None]:
class TransformerAugmenter():
    """
    Use a pretrained masked language model to generate more labeled samples
    from one labeled sentence.

    A sentence is a list of ``(word, pos, tag)`` tuples. Each augmentation
    round masks one randomly chosen position, asks the ``fill-mask`` pipeline
    for candidate words, and puts a randomly chosen candidate at that
    position. The POS and NER tags of that position are kept unchanged, so
    replacing an entity token (e.g. a ``B-geo`` word) yields a new word that
    still carries the old entity tag.
    """

    def __init__(self, model="distilroberta-base", num_sample_tokens=5):
        """Build the ``fill-mask`` pipeline.

        Args:
            model: name of the masked language model to load.
            num_sample_tokens: how many candidates the pipeline returns per
                mask; one of them is picked at random.
        """
        self.num_sample_tokens = num_sample_tokens
        self.fill_mask = pipeline(
            "fill-mask",
            # NOTE(review): `topk` is the keyword of older transformers
            # releases; newer ones call it `top_k`.
            topk=self.num_sample_tokens,
            model=model
        )

    def generate(self, sentence, num_replace_tokens=3):
        """Return ``[sentence, augmented_sentence]``.

        Args:
            sentence: list of ``(word, pos, tag)`` tuples; not modified.
            num_replace_tokens: number of mask-and-fill rounds. Positions are
                drawn independently each round, so the same position may be
                replaced more than once.

        Returns:
            A two-element list: the original ``sentence`` and a new list of the
            same length in which the replaced positions carry predicted words.
            An empty sentence yields ``[sentence, []]``.
        """
        augmented_sentence = list(sentence)
        if not augmented_sentence:
            return [sentence, augmented_sentence]

        mask_token = self.fill_mask.tokenizer.mask_token
        for _ in range(num_replace_tokens):
            # Pick a position rather than a word: masking with str.replace on
            # the joined text can hit a substring of another word (e.g. "on"
            # inside "one", "." inside "U.S."), and replacing by value would
            # also overwrite every other occurrence of the same word.
            idx = random.randrange(len(augmented_sentence))
            words = [w[0] for w in augmented_sentence]
            words[idx] = mask_token
            masked_text = " ".join(words)

            predictions = self.fill_mask(masked_text)
            # Be lenient in case a transformers version returns a bare dict.
            if isinstance(predictions, dict):
                predictions = [predictions]
            if not predictions:
                continue
            res = random.choice(predictions)
            # token_str may include RoBERTa's leading-space marker "Ġ"
            # (depending on the transformers version); drop it and whitespace.
            new_word = res["token_str"].replace("Ġ", "").strip()
            if not new_word:
                # Keep the original word rather than insert an empty token.
                continue

            old = augmented_sentence[idx]
            augmented_sentence[idx] = (new_word, old[1], old[2])
        return [sentence, augmented_sentence]


In [None]:
# Loads distilroberta-base through the fill-mask pipeline (downloaded on first run).
augmenter = TransformerAugmenter()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=480.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898823.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=230.0, style=ProgressStyle(description_…






HBox(children=(FloatProgress(value=0.0, description='Downloading', max=331070498.0, style=ProgressStyle(descri…




Some weights of RobertaForMaskedLM were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['lm_head.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# The training sentence used to demonstrate augmentation below.
train_sentences[0]

[('In', 'IN', 'O'),
 ('other', 'JJ', 'O'),
 ('violence', 'NN', 'O'),
 (',', ',', 'O'),
 ('U.S.', 'NNP', 'B-gpe'),
 ('officials', 'NNS', 'O'),
 ('said', 'VBD', 'O'),
 ('one', 'CD', 'O'),
 ('American', 'JJ', 'B-gpe'),
 ('soldier', 'NN', 'O'),
 ('was', 'VBD', 'O'),
 ('killed', 'VBN', 'O'),
 ('while', 'IN', 'O'),
 ('on', 'IN', 'O'),
 ('patrol', 'NN', 'O'),
 ('in', 'IN', 'O'),
 ('Baghdad', 'NNP', 'B-geo'),
 ('Sunday', 'NNP', 'B-tim'),
 ('.', '.', 'O')]

In [None]:
# Augment one training sentence; generate() returns [original, augmented].
# A dedicated name keeps this demo separate from `augmented_sentences`,
# which holds the full augmented dataset.
example_augmentation = augmenter.generate(train_sentences[0], num_replace_tokens=7)
example_augmentation

	nonzero()
Consider using one of the following signatures instead:
	nonzero(*, bool as_tuple) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:766.)
  masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero()


[[('In', 'IN', 'O'),
  ('other', 'JJ', 'O'),
  ('violence', 'NN', 'O'),
  (',', ',', 'O'),
  ('U.S.', 'NNP', 'B-gpe'),
  ('officials', 'NNS', 'O'),
  ('said', 'VBD', 'O'),
  ('one', 'CD', 'O'),
  ('American', 'JJ', 'B-gpe'),
  ('soldier', 'NN', 'O'),
  ('was', 'VBD', 'O'),
  ('killed', 'VBN', 'O'),
  ('while', 'IN', 'O'),
  ('on', 'IN', 'O'),
  ('patrol', 'NN', 'O'),
  ('in', 'IN', 'O'),
  ('Baghdad', 'NNP', 'B-geo'),
  ('Sunday', 'NNP', 'B-tim'),
  ('.', '.', 'O')],
 [('among', 'IN', 'O'),
  ('other', 'JJ', 'O'),
  ('violence', 'NN', 'O'),
  (',', ',', 'O'),
  ('U.S.', 'NNP', 'B-gpe'),
  ('officials', 'NNS', 'O'),
  ('said', 'VBD', 'O'),
  ('one', 'CD', 'O'),
  ('American', 'JJ', 'B-gpe'),
  ('soldier', 'NN', 'O'),
  ('was', 'VBD', 'O'),
  ('killed', 'VBN', 'O'),
  ('when', 'IN', 'O'),
  ('on', 'IN', 'O'),
  ('patrol', 'NN', 'O'),
  ('in', 'IN', 'O'),
  ('Baghdad', 'NNP', 'B-geo'),
  ('Sunday', 'NNP', 'B-tim'),
  ('.-', '.', 'O')]]

### Generate an augmented dataset

In [None]:
# How many training sentences are available to augment.
len(train_sentences)

27959

In [None]:
# Augment only the first thousand training sentences (it can take about 10 mins).
# Each call contributes [original, augmented], so the result has 2 * n_sentences items.
n_sentences = 1000

augmented_sentences = [
    pair_member
    for sentence in tqdm(train_sentences[:n_sentences])
    for pair_member in augmenter.generate(sentence, num_replace_tokens=7)
]

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))




In [None]:
# Each input sentence contributes two entries: the original and its augmented copy.
len(augmented_sentences)

2000