In [1]:
from os import path
from pprint import PrettyPrinter, pprint
from typing import Optional

# Directory the notebook lives in. Under IPython/Jupyter the `_dh` global holds
# the kernel's directory history and its first entry is used. When `_dh` is not
# defined (e.g. the cells are run as a plain Python script) fall back to the
# current working directory instead of failing with a KeyError.
__DIR__ = globals().get("_dh", [path.abspath(path.curdir)])[0]

# Data folder is a sibling of the notebook's directory, expressed relative to
# the current working directory.
data_dir = path.relpath(path.join(__DIR__, "..", "_data"))

In [2]:
import pandas as pd

# Widen pandas' console rendering so the long token / tag lists shown below
# are readable: total display width and per-cell truncation width.
display_options = {
    "display.width": 120,
    "display.max_colwidth": 90,
}
for option_name, option_value in display_options.items():
    pd.set_option(option_name, option_value)

In [3]:
from datasets import load_dataset

# Load every split of the "conllpp" dataset.
dataset = load_dataset("conllpp")

# Number of distinct NER tag classes, as declared by the per-token feature of
# the training split's "ner_tags" column; used later to size the classifier.
ner_tag_feature = dataset["train"].features["ner_tags"].feature
num_classes = ner_tag_feature.num_classes

# Preview the first five raw training samples.
display(pd.DataFrame(dataset["train"][:5]))

Reusing dataset conllpp (/Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2)


  0%|          | 0/3 [00:00<?, ?it/s]

Unnamed: 0,id,tokens,pos_tags,chunk_tags,ner_tags
0,0,"[EU, rejects, German, call, to, boycott, British, lamb, .]","[22, 42, 16, 21, 35, 37, 16, 21, 7]","[11, 21, 11, 12, 21, 22, 11, 12, 0]","[3, 0, 7, 0, 0, 0, 7, 0, 0]"
1,1,"[Peter, Blackburn]","[22, 22]","[11, 12]","[1, 2]"
2,2,"[BRUSSELS, 1996-08-22]","[22, 11]","[11, 12]","[5, 0]"
3,3,"[The, European, Commission, said, on, Thursday, it, disagreed, with, German, advice, t...","[12, 22, 22, 38, 15, 22, 28, 38, 15, 16, 21, 35, 24, 35, 37, 16, 21, 15, 24, 41, 15, 1...","[11, 12, 12, 21, 13, 11, 11, 21, 13, 11, 12, 13, 11, 21, 22, 11, 12, 17, 11, 21, 17, 1...","[0, 3, 4, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0..."
4,4,"[Germany, 's, representative, to, the, European, Union, 's, veterinary, committee, Wer...","[22, 27, 21, 35, 12, 22, 22, 27, 16, 21, 22, 22, 38, 15, 22, 24, 20, 37, 21, 15, 24, 1...","[11, 11, 12, 13, 11, 12, 12, 11, 12, 12, 12, 12, 21, 13, 11, 12, 21, 22, 11, 13, 11, 1...","[5, 0, 0, 0, 0, 3, 4, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0..."


In [4]:
import torch
from transformers import BertTokenizerFast

# Label value written at positions that carry no NER tag (special tokens,
# padding, non-leading sub-word pieces). Background:
# https://towardsdatascience.com/named-entity-recognition-with-bert-in-pytorch-a454405e0b6a
# NOTE(review): presumably chosen to match the loss function's default
# ignore index - confirm against the model's loss.
null_label = -100

# Fixed sequence length every processed sample is padded / truncated to.
max_length = 128

tokenizer_dir = path.join(data_dir, "pretrain", "tokenizer")
tokenizer = BertTokenizerFast.from_pretrained(tokenizer_dir)

# Extra per-position attributes shared by all special tokens: they belong to
# no source word and carry the null label.
spc_tok_attr = {"word_ids": [None], "labels": [null_label]}


def _encode_special_token(token):
    """Encode a single special token on its own and attach the shared attributes."""
    return tokenizer(token, add_special_tokens=False) | spc_tok_attr


# One-position encodings that are concatenated around / after each sample.
cls_token = _encode_special_token(tokenizer.cls_token)
sep_token = _encode_special_token(tokenizer.sep_token)
# Padding positions are additionally masked out of attention.
pad_token = _encode_special_token(tokenizer.pad_token) | {"attention_mask": [0]}

def process_dataset(ds, null_label=null_label, max_length=max_length, add_special_tokens=True, num_proc=4):
    """Tokenize a word-level NER dataset into flat, label-aligned model inputs.

    Every word of ``sample["tokens"]`` is tokenized on its own with the
    module-level ``tokenizer``; the per-word pieces are then concatenated into
    one sequence per sample.

    Args:
        ds: Dataset whose samples provide ``"tokens"`` (list of words) and
            ``"ner_tags"`` (one tag per word). All of its original columns are
            dropped from the result.
        null_label: Label written at every position that is not the first
            piece of a word (continuation pieces, special tokens, padding).
        max_length: If not ``None``, every output column is padded / truncated
            to exactly this many positions. Truncation is applied to the word
            pieces *before* the special tokens are attached, so the closing
            separator token is kept whenever there is room for it.
        add_special_tokens: Whether to wrap each sequence with the module-level
            ``cls_token`` / ``sep_token`` encodings.
        num_proc: Number of worker processes passed to ``ds.map``.

    Returns:
        The mapped dataset with the tokenizer's output columns plus
        ``word_ids`` (index of the source word per position, ``None`` for
        special / padding positions), ``labels``, ``words`` (copy of the input
        words) and ``text`` (decoded input ids without special tokens).
    """
    n_special = 2 if add_special_tokens else 0

    def special_values(special, key):
        # The module-level special-token encodings were built with the global
        # null label; honour the value passed to this call instead.
        return [null_label] if key == "labels" else special[key]

    def process_sample(sample):
        per_word = tokenizer(sample["tokens"], add_special_tokens=False)
        pieces_per_word = [len(ids) for ids in per_word["input_ids"]]

        # Flatten the per-word lists into one list per column (linear time).
        encoding = {k: [x for word_values in v for x in word_values] for k, v in per_word.items()}

        # Propagate word ids (based on words from sample["tokens"]).
        encoding["word_ids"] = [i for i, n in enumerate(pieces_per_word) for _ in range(n)]

        # Propagate NER tags as labels: the tag goes on the first piece of each
        # word, the null label on the rest. A word that yields no pieces
        # contributes no label, keeping labels aligned with input_ids.
        encoding["labels"] = [
            tag if j == 0 else null_label
            for tag, n in zip(sample["ner_tags"], pieces_per_word)
            for j in range(n)
        ]

        n_tokens = len(encoding["input_ids"])
        # Number of word pieces that fit once room is reserved for the
        # special tokens.
        if max_length is None:
            keep = n_tokens
        else:
            keep = min(n_tokens, max(max_length - n_special, 0))

        for k, v in encoding.items():
            # Sanity check 1: every column describes the same positions.
            assert len(v) == n_tokens, f"expected {k} of length {n_tokens}, got {len(v)}"

            v = v[:keep]

            # Attach info from special tokens.
            if add_special_tokens:
                v = special_values(cls_token, k) + v + special_values(sep_token, k)

            # Padding; the final slice only matters when max_length is smaller
            # than the number of special tokens.
            if max_length is not None:
                v = (v + special_values(pad_token, k) * (max_length - len(v)))[:max_length]

                # Sanity check 2
                assert len(v) == max_length

            encoding[k] = v

        # Provide concatenated text and a copy of the words.
        encoding["words"] = sample["tokens"]
        encoding["text"] = tokenizer.decode(encoding["input_ids"], skip_special_tokens=True)

        return encoding

    # Pass the original column names explicitly as a list of strings.
    return ds.map(process_sample, remove_columns=list(ds.features), num_proc=num_proc)

# Run the same preprocessing over the three splits, in train / validation /
# test order, using process_dataset's default settings.
train_dataset, val_dataset, test_dataset = (
    process_dataset(dataset[split_name])
    for split_name in ("train", "validation", "test")
)

# Preview the first five processed training samples.
display(pd.DataFrame(train_dataset[:5]))

     

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-36633bac7587cd66.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-e8d6af8df9b88f95.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-cf90090c448cb885.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-d8bab431a5f84004.arrow


     

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-fb0688efb18a311d.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-9aaf6b4821f4ad4a.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-ea5a3c1fd7ed84bd.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-e98e99a7f315e49a.arrow


     

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-54f82d3204575459.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-02bd9d8e4b3164cf.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-ffd59deb17c2fb08.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-8c58579de5e79708.arrow


Unnamed: 0,input_ids,token_type_ids,attention_mask,word_ids,labels,words,text
0,"[2, 11829, 6647, 99, 885, 907, 179, 5989, 15573, 1228, 754, 113, 103, 18, 3, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[None, 0, 1, 1, 2, 3, 4, 5, 5, 6, 7, 7, 7, 8, None, None, None, None, None, None, None...","[-100, 3, 0, -100, 7, 0, 0, 0, -100, 7, 0, -100, -100, 0, -100, -100, -100, -100, -100...","[EU, rejects, German, call, to, boycott, British, lamb, .]",EU rejects German call to boycott British lamb.
1,"[2, 3501, 2503, 3534, 110, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[None, 0, 1, 1, 1, None, None, None, None, None, None, None, None, None, None, None, N...","[-100, 1, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -...","[Peter, Blackburn]",Peter Blackburn
2,"[2, 37, 1761, 107, 3536, 143, 107, 1801, 17, 8785, 17, 1908, 3, 0, 0, 0, 0, 0, 0, 0, 0...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[None, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, None, None, None, None, None, None, None, None...","[-100, 5, -100, -100, -100, -100, -100, 0, -100, -100, -100, -100, -100, -100, -100, -...","[BRUSSELS, 1996-08-22]",BRUSSELS 1996 - 08 - 22
3,"[2, 192, 2801, 5047, 1411, 201, 389, 1525, 1850, 314, 480, 16107, 159, 219, 885, 1831,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1...","[None, 0, 1, 2, 3, 4, 5, 5, 5, 6, 7, 7, 7, 8, 9, 10, 10, 11, 12, 13, 14, 14, 15, 16, 1...","[-100, 0, 3, 4, 0, 0, 0, -100, -100, 0, 0, -100, -100, 0, 7, 0, -100, 0, 0, 0, 0, -100...","[The, European, Commission, said, on, Thursday, it, disagreed, with, German, advice, t...",The European Commission said on Thursday it disagreed with German advice to consumers ...
4,"[2, 1963, 11, 83, 7565, 179, 155, 2801, 2096, 11, 83, 10315, 15768, 8244, 2157, 10340,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1...","[None, 0, 1, 1, 2, 3, 4, 5, 6, 7, 7, 8, 8, 9, 10, 10, 10, 11, 11, 11, 11, 12, 13, 14, ...","[-100, 5, 0, -100, 0, 0, 0, 3, 4, 0, -100, 0, -100, 0, 1, -100, -100, 2, -100, -100, -...","[Germany, 's, representative, to, the, European, Union, 's, veterinary, committee, Wer...",Germany's representative to the European Union's veterinary committee Werner Zwingmann...


In [5]:
import numpy as np
from transformers import BertForTokenClassification, Trainer, TrainingArguments

# Output folder for the NER model and its checkpoints.
ner_dir = path.join(data_dir, "ner")

# Start from the locally stored pretrained model and attach a token
# classification head with one output per NER tag class.
pretrained_model_dir = path.join(data_dir, "pretrain", "model")
model = BertForTokenClassification.from_pretrained(pretrained_model_dir, num_labels=num_classes)

# NOTE(review): the training log below reports 220 optimization steps in total,
# so the 10000-step logging / save intervals are never reached within this
# single epoch - confirm that is intended.
training_args = TrainingArguments(
    output_dir=ner_dir,
    overwrite_output_dir=True,
    optim="adamw_torch",
    num_train_epochs=1,
    per_device_train_batch_size=64,
    eval_accumulation_steps=10,
    evaluation_strategy="steps",
    logging_steps=10000,
    save_steps=10000,
    save_total_limit=3,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

Some weights of the model checkpoint at ../_data/pretrain/model were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at ../_data/pretrain/model and are n

In [6]:
# Run training with the Trainer configured in the previous cell, then write
# the resulting model to `ner_dir`.
trainer.train()
trainer.save_model(ner_dir)

The following columns in the training set don't have a corresponding argument in `BertForTokenClassification.forward` and have been ignored: word_ids, words, text. If word_ids, words, text are not expected by `BertForTokenClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 14041
  Num Epochs = 1
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 220


  0%|          | 0/220 [00:00<?, ?it/s]

KeyboardInterrupt: 