# Setup

In [1]:
from os import path
import pandas as pd
from pprint import PrettyPrinter, pprint
from typing import Optional

__DIR__ = globals()['_dh'][0]
data_dir = path.relpath(path.join(__DIR__, "..", "_data"))

pp = PrettyPrinter(indent=2, width=120)

pd.set_option('display.width', 120)
pd.set_option('display.max_colwidth', 90)

In [2]:
# Settings
_colab_install = True
_testing = True

# Parameters
tokenizer_dir = path.join(data_dir, "pretrain", "tokenizer")
model_dir = path.join(data_dir, "pretrain", "model")
ner_dir = path.join(data_dir, "ner")

null_label = -100   # https://towardsdatascience.com/named-entity-recognition-with-bert-in-pytorch-a454405e0b6a
max_length = 128

training_args = dict(
    optim = "adamw_torch",
    num_train_epochs = 5,
    per_device_train_batch_size = 64,
    eval_accumulation_steps = 10,

    evaluation_strategy = "epoch",
    logging_strategy = "epoch",
    save_strategy = "epoch",
    
    load_best_model_at_end = True
    save_total_limit = 3,
)

## Process settings / parameters

In [3]:
if _testing:
    training_args.update(dict(
        num_train_epochs = 1
    ))

In [4]:
from collections import OrderedDict

if _colab_install:
    try:
        import google.colab
        
        colab_install_script = path.join(__DIR__, "..", "colab_install.sh")

        if not path.isfile(colab_install_script):
            script_url = "https://raw.githubusercontent.com/yenson-lau/pii-remediation/main/colab_install.sh"
            !wget $script_url -O $colab_install_script

        !bash $colab_install_script
        print()

    except ModuleNotFoundError:
        pass

config = OrderedDict(
    tokenizer_dir = tokenizer_dir,
    model_dir = model_dir,
    ner_dir = ner_dir,

    null_label = null_label,
    max_length = max_length,

    training_args = training_args,
)

print("NER finetuning on conllpp dataset:")
pp.pprint(config)

NER finetuning on conllpp dataset:
OrderedDict([ ('tokenizer_dir', '../_data/pretrain/tokenizer'),
              ('model_dir', '../_data/pretrain/model'),
              ('ner_dir', '../_data/ner'),
              ('null_label', -100),
              ('max_length', 128),
              ( 'training_args',
                { 'eval_accumulation_steps': 10,
                  'evaluation_strategy': 'epoch',
                  'logging_strategy': 'epoch',
                  'num_train_epochs': 1,
                  'optim': 'adamw_torch',
                  'per_device_train_batch_size': 64,
                  'save_total_limit': 3})])


# Load / process dataset

In [5]:
from datasets import load_dataset

dataset = load_dataset("conllpp")
num_classes = dataset["train"].features["ner_tags"].feature.num_classes

display(pd.DataFrame(dataset["train"][:5]))

Reusing dataset conllpp (/Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2)


  0%|          | 0/3 [00:00<?, ?it/s]

Unnamed: 0,id,tokens,pos_tags,chunk_tags,ner_tags
0,0,"[EU, rejects, German, call, to, boycott, British, lamb, .]","[22, 42, 16, 21, 35, 37, 16, 21, 7]","[11, 21, 11, 12, 21, 22, 11, 12, 0]","[3, 0, 7, 0, 0, 0, 7, 0, 0]"
1,1,"[Peter, Blackburn]","[22, 22]","[11, 12]","[1, 2]"
2,2,"[BRUSSELS, 1996-08-22]","[22, 11]","[11, 12]","[5, 0]"
3,3,"[The, European, Commission, said, on, Thursday, it, disagreed, with, German, advice, t...","[12, 22, 22, 38, 15, 22, 28, 38, 15, 16, 21, 35, 24, 35, 37, 16, 21, 15, 24, 41, 15, 1...","[11, 12, 12, 21, 13, 11, 11, 21, 13, 11, 12, 13, 11, 21, 22, 11, 12, 17, 11, 21, 17, 1...","[0, 3, 4, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0..."
4,4,"[Germany, 's, representative, to, the, European, Union, 's, veterinary, committee, Wer...","[22, 27, 21, 35, 12, 22, 22, 27, 16, 21, 22, 22, 38, 15, 22, 24, 20, 37, 21, 15, 24, 1...","[11, 11, 12, 13, 11, 12, 12, 11, 12, 12, 12, 12, 21, 13, 11, 12, 21, 22, 11, 13, 11, 1...","[5, 0, 0, 0, 0, 3, 4, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0..."


In [6]:
import torch
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained(tokenizer_dir)

spc_tok_attr = {"word_ids": [None], "labels": [null_label]}
cls_token = {**tokenizer(tokenizer.cls_token, add_special_tokens=False), **spc_tok_attr}
sep_token = {**tokenizer(tokenizer.sep_token, add_special_tokens=False), **spc_tok_attr}
pad_token = {**tokenizer(tokenizer.pad_token, add_special_tokens=False), **spc_tok_attr, "attention_mask": [0]}

def process_dataset(ds, null_label=null_label, max_length=max_length, add_special_tokens=True, num_proc=4):
    def process_sample(sample):
        encoding = tokenizer(sample["tokens"], add_special_tokens=False)
        
        # propagate word ids (based on words from sample["tokens"])
        encoding["word_ids"] = [[i] * len(input_ids) for i, input_ids in enumerate(encoding["input_ids"])]  

        # propagate ner tags as labels
        encoding["labels"] = [[tag] + [null_label] * (len(input_ids)-1) 
                              for tag, input_ids in zip(sample["ner_tags"], encoding["input_ids"])]

        # concat
        encoding = {k: sum(v, []) for k, v in encoding.items()}
        expected_encoding_length = len(encoding["input_ids"]) + (2 if add_special_tokens else 0)
        if max_length is not None:
            pad_length = max_length - expected_encoding_length

        for k, v in encoding.items():

            # append info from special_tokens
            if add_special_tokens:
                v = cls_token[k] + v + sep_token[k]

            # sanity check 1
            assert len(v) == expected_encoding_length, f"expected {k} of length {expected_encoding_length}, got {len(v)}"

            # padding / truncation
            if max_length is not None:
                v = v + pad_token[k] * pad_length if pad_length>0 else v[:max_length]

                # sanity check 2
                assert len(v) == max_length

            encoding[k] = v

        # provide concatenated text and a copy of the words
        encoding["words"] = sample["tokens"]
        encoding["text"] = tokenizer.decode(encoding["input_ids"], skip_special_tokens=True)
        
        return encoding

    return ds.map(process_sample, remove_columns=ds.features, num_proc=num_proc)

train_dataset = process_dataset(dataset["train"])
val_dataset = process_dataset(dataset["validation"])
test_dataset = process_dataset(dataset["test"])

display(pd.DataFrame(train_dataset[:5]))
display(pd.DataFrame(val_dataset[:5]))

     

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-074e680a38e10812.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-3dcb986fe396d27d.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-7d54e14e0c35d295.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-a32997c640cc8418.arrow


     

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-bf1d9453b72b05ba.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-847866618d3c1cb2.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-1b5305b561029d08.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-a2be676345c8ca8a.arrow


     

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-805d7a9ad640102e.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-02562645a2bb1a93.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-3790734082dac5d9.arrow


 

Loading cached processed dataset at /Users/yenson/.cache/huggingface/datasets/conllpp/conllpp/1.0.0/04f15f257dff3fe0fb36e049b73d51ecdf382698682f5e590b7fb13898206ba2/cache-86ab6f968f43387c.arrow


Unnamed: 0,input_ids,token_type_ids,attention_mask,word_ids,labels,words,text
0,"[2, 11829, 6647, 99, 885, 907, 179, 5989, 15573, 1228, 754, 113, 103, 18, 3, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[None, 0, 1, 1, 2, 3, 4, 5, 5, 6, 7, 7, 7, 8, None, None, None, None, None, None, None...","[-100, 3, 0, -100, 7, 0, 0, 0, -100, 7, 0, -100, -100, 0, -100, -100, -100, -100, -100...","[EU, rejects, German, call, to, boycott, British, lamb, .]",EU rejects German call to boycott British lamb.
1,"[2, 3501, 2503, 3534, 110, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[None, 0, 1, 1, 1, None, None, None, None, None, None, None, None, None, None, None, N...","[-100, 1, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -...","[Peter, Blackburn]",Peter Blackburn
2,"[2, 37, 1761, 107, 3536, 143, 107, 1801, 17, 8785, 17, 1908, 3, 0, 0, 0, 0, 0, 0, 0, 0...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[None, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, None, None, None, None, None, None, None, None...","[-100, 5, -100, -100, -100, -100, -100, 0, -100, -100, -100, -100, -100, -100, -100, -...","[BRUSSELS, 1996-08-22]",BRUSSELS 1996 - 08 - 22
3,"[2, 192, 2801, 5047, 1411, 201, 389, 1525, 1850, 314, 480, 16107, 159, 219, 885, 1831,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1...","[None, 0, 1, 2, 3, 4, 5, 5, 5, 6, 7, 7, 7, 8, 9, 10, 10, 11, 12, 13, 14, 14, 15, 16, 1...","[-100, 0, 3, 4, 0, 0, 0, -100, -100, 0, 0, -100, -100, 0, 7, 0, -100, 0, 0, 0, 0, -100...","[The, European, Commission, said, on, Thursday, it, disagreed, with, German, advice, t...",The European Commission said on Thursday it disagreed with German advice to consumers ...
4,"[2, 1963, 11, 83, 7565, 179, 155, 2801, 2096, 11, 83, 10315, 15768, 8244, 2157, 10340,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1...","[None, 0, 1, 1, 2, 3, 4, 5, 6, 7, 7, 8, 8, 9, 10, 10, 10, 11, 11, 11, 11, 12, 13, 14, ...","[-100, 5, 0, -100, 0, 0, 0, 3, 4, 0, -100, 0, -100, 0, 1, -100, -100, 2, -100, -100, -...","[Germany, 's, representative, to, the, European, Union, 's, veterinary, committee, Wer...",Germany's representative to the European Union's veterinary committee Werner Zwingmann...


Unnamed: 0,input_ids,token_type_ids,attention_mask,word_ids,labels,words,text
0,"[2, 14613, 15718, 144, 15734, 17, 47, 145, 15718, 4679, 5626, 12414, 106, 137, 139, 14...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1...","[None, 0, 0, 0, 0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 6, 6, 7, 7...","[-100, 0, -100, -100, -100, 0, 3, -100, -100, -100, -100, -100, -100, -100, -100, -100...","[CRICKET, -, LEICESTERSHIRE, TAKE, OVER, AT, TOP, AFTER, INNINGS, VICTORY, .]",CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY.
1,"[2, 47, 116, 140, 146, 116, 140, 1801, 17, 8785, 17, 1450, 3, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[None, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, None, None, None, None, None, None, None, None...","[-100, 5, -100, -100, -100, -100, -100, 0, -100, -100, -100, -100, -100, -100, -100, -...","[LONDON, 1996-08-30]",LONDON 1996 - 08 - 30
2,"[2, 1730, 2703, 457, 17, 2945, 156, 2712, 4365, 6593, 1589, 902, 200, 3623, 201, 7767,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1...","[None, 0, 1, 2, 2, 2, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14, 14, 1...","[-100, 7, 8, 0, -100, -100, -100, 1, 2, -100, 0, 0, 0, 0, 0, 0, 0, 3, -100, -100, -100...","[West, Indian, all-rounder, Phil, Simmons, took, four, for, 38, on, Friday, as, Leices...",West Indian all - rounder Phil Simmons took four for 38 on Friday as Leicestershire be...
3,"[2, 3551, 3556, 201, 1765, 16, 1743, 16, 829, 228, 1422, 17, 3802, 212, 2006, 7200, 70...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1...","[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 10, 11, 12, 13, 14, 15, 15, 16, 17, 18, 19,...","[-100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -100, -100, 0, 0, 0, 3, 0, 3, -100, 0, 3, 0, 0, 0...","[Their, stay, on, top, ,, though, ,, may, be, short-lived, as, title, rivals, Essex, ,...","Their stay on top, though, may be short - lived as title rivals Essex, Derbyshire and ..."
4,"[2, 946, 18672, 1913, 3394, 189, 489, 200, 7801, 201, 155, 3449, 4645, 246, 13264, 211...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1...","[None, 0, 1, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 14, 14, 14, 15, 16, 17,...","[-100, 0, 0, 3, -100, -100, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 3, -100, -100, -100, 0, 0...","[After, bowling, Somerset, out, for, 83, on, the, opening, morning, at, Grace, Road, ,...","After bowling Somerset out for 83 on the opening morning at Grace Road, Leicestershire..."


# NER finetuning

In [7]:
# For evaluation
import numpy as np

np.random.seed(0)

eval_dataset = val_dataset

def get_entities(pred, sample):
    valid_sequence_length = sum(sample["attention_mask"]) - 2   # subtract cls / sep tokens
    pred = torch.argmax(pred, axis=-1).flatten()[1:valid_sequence_length+1]
    pred_idxs = torch.nonzero(pred).flatten()

    words = {i: word for i, word in enumerate(sample["words"])}
    word_ids = sample["word_ids"][1:valid_sequence_length+1]

    entities = OrderedDict()
    for idx in pred_idxs:
        word = words.get(word_ids[idx], "[INV]")
        entities[word] = entities.get(word, []) + [int(pred[idx])]

    return entities

def eval_random_samps(eval_preds, n_samps=5):
    preds = eval_preds.predictions

    print("\nEVALUATING ON RANDOM SAMPLES:\n")
    for idx in np.random.permutation(len(preds))[:n_samps]:
        sample = eval_dataset[int(idx)]
        pp.pprint(sample["text"])
        pp.pprint(get_entities(torch.from_numpy(preds[idx]), sample))
        print()

    return dict()

In [8]:
from transformers import BertForTokenClassification, Trainer, TrainingArguments, TrainerCallback

model = BertForTokenClassification.from_pretrained(path.join(data_dir, "pretrain", "model"), num_labels=num_classes)

train_args = TrainingArguments(output_dir = ner_dir,
                                  overwrite_output_dir = True,
                                  **training_args)

trainer = Trainer(model = model,
                  args = train_args,
                  compute_metrics=eval_random_samps,
                  train_dataset = train_dataset,
                  eval_dataset = eval_dataset)

Some weights of the model checkpoint at ../_data/pretrain/model were not used when initializing BertForTokenClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at ../_data/pretrain/model and are n

In [9]:
trainer.train()
trainer.save_model(ner_dir)

The following columns in the training set don't have a corresponding argument in `BertForTokenClassification.forward` and have been ignored: text, words, word_ids. If text, words, word_ids are not expected by `BertForTokenClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 14041
  Num Epochs = 1
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 220


  0%|          | 0/220 [00:00<?, ?it/s]

The following columns in the evaluation set don't have a corresponding argument in `BertForTokenClassification.forward` and have been ignored: text, words, word_ids. If text, words, word_ids are not expected by `BertForTokenClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 3250
  Batch size = 8


{'loss': 0.4908, 'learning_rate': 0.0, 'epoch': 1.0}


  0%|          | 0/407 [00:00<?, ?it/s]



Training completed. Do not forget to share your model on huggingface.co/models =)


Saving model checkpoint to ../_data/ner
Configuration saved in ../_data/ner/config.json


EVALUATING ON RANDOM SAMPLES:
'World Group II, first round ( March 1 - 2 )'
OrderedDict([('World', [3]), ('Group', [5]), ('II', [4])])

'M. Maynard run out 1'
OrderedDict([('M.', [3]), ('Maynard', [2])])

'Manchester City 3 1 0 2 2 3 3'
OrderedDict([('Manchester', [3, 4]), ('City', [4])])

'NFL AMERICAN FOOTBALL - RANDALL CUNNINGHAM RETIRES.'
OrderedDict()

("The detention of veteran dissident Wang Donghai showed China's determination to crush any vestige of dissent during "
 "the current profound transitions in the nation's leadership, a human rights activist said on Saturday.")
OrderedDict([('Donghai', [1, 5]), ('China', [5])])

{'eval_loss': 0.32919201254844666, 'eval_runtime': 102.5562, 'eval_samples_per_second': 31.69, 'eval_steps_per_second': 3.969, 'epoch': 1.0}
{'train_runtime': 1531.6525, 'train_samples_per_second': 9.167, 'train_steps_per_second': 0.144, 'train_loss': 0.4907617742365057, 'epoch': 1.0}


Model weights saved in ../_data/ner/pytorch_model.bin


# Evaluation

In [10]:
np.random.seed(0)

samples = np.random.permutation(len(val_dataset))[:5]
samples = [val_dataset[int(i)] for i in samples]

In [11]:
device = "cuda" if torch.cuda.is_available() else "cpu"

def to_tensor(sample):
    return torch.tensor(sample).view(1,-1).to(device)

preds = [model(input_ids=to_tensor(sample["input_ids"]),
               attention_mask=to_tensor(sample["attention_mask"]),
               token_type_ids=to_tensor(sample["token_type_ids"])
               ).logits.cpu() for sample in samples]

pred_ents = [get_entities(pred, sample) for pred, sample in zip(preds, samples)]

In [12]:
for sample, pred in zip(samples, pred_ents):
    pp.pprint(sample["text"])
    pp.pprint(pred)
    print()

('Derbyshire, nine - wicket winners over Worcestershire, and Surrey, who thrashed Warwickshire by an innings and 164 '
 'runs, can instead take the day off along with rivals Leicestershire, who beat Somerset inside two days.')
OrderedDict([('winners', [1]), ('Surrey', [3]), ('thrashed', [7]), ('Leicestershire', [3])])

'Fulham 4 3 0 1 5 3 9'
OrderedDict([('Fulham', [3, 4])])

'Mahala is a Moslem village on Bosnian Serb republic territory.'
OrderedDict([('Moslem', [7]), ('Bosnian', [7]), ('Serb', [7])])

('Nyerere arrived in Rome this week on a private visit and held talks with the U. S. special envoy to Burundi, Howard '
 "Wolpe, and the Sant'Egidio Community, an Italian Roman Catholic organisation which has been monitoring Burundi "
 'closely.')
OrderedDict([ ('Rome', [7]),
              ('U.S.', [5, 1]),
              ('envoy', [1]),
              ('Burundi', [5, 7]),
              ('Howard', [1]),
              ('Egidio', [3]),
              ('Italian', [5]),
              ('Roman',