In [1]:
def get_index(cur_entry, tokens):
    """Estimate the character offset at which a predicted token starts in the text.

    The offset is reconstructed from the word-piece tokens alone, using a
    heuristic: every token that starts with a letter is assumed to be preceded
    by exactly one space, while word-piece continuations ('##xyz') and tokens
    that start with a non-letter (e.g. punctuation) are assumed to be glued to
    the previous token.  Text that does not follow this layout will yield wrong
    offsets.

    Args:
        cur_entry: prediction dict; reads ``cur_entry['index']`` (position of
            the token counted with [CLS] at position 0, so the first real
            token has index 1) and ``cur_entry['word']`` (the token string).
        tokens: list of token strings for the text, without special tokens.

    Returns:
        int: estimated start offset of the token in the original text.
    """
    def is_continuation(tok):
        # Only a leading '##' followed by something marks a word-piece
        # continuation; a literal '#' in the text is an ordinary token.
        return len(tok) > 2 and tok.startswith('##')

    preceding = tokens[:cur_entry['index'] - 1]  # -1 because we start with [CLS]

    offset = 0
    for i, tok in enumerate(preceding):
        continuation = is_continuation(tok)
        # Length of the token as it appears in the text (continuation marker removed).
        offset += len(tok) - 2 if continuation else len(tok)
        # Whitespace before every word (not subword or punctuation) except the first.
        if i and not continuation and tok[:1].isalpha():
            offset += 1

    word = cur_entry['word']
    # Whitespace before our own token, unless it is a subword or the very
    # first token of the text (which has nothing in front of it).
    if preceding and not is_continuation(word) and word[:1].isalpha():
        offset += 1
    return offset

def post_process(item):
    """Force word-piece continuations to carry an inside ('I…') label.

    This is a heuristic fix-up (the "dirty hack"): when the token in
    ``item['word']`` is a word-piece continuation ('##xyz'), the first
    character of ``item['entity']`` is replaced by 'I', so the piece cannot
    open a new entity.  The dict is modified in place and returned.

    Args:
        item: prediction dict with at least 'word' and 'entity' keys.

    Returns:
        The same ``item`` object.
    """
    word = item['word']
    # Require the full '##' prefix plus at least one character, so a literal
    # '#' token from the text is not mistaken for a continuation.
    if len(word) > 2 and word.startswith('##'):
        item['entity'] = 'I' + item['entity'][1:]
    return item

def insert_index(item, tokens):
    """Add 'start' and 'end' character offsets to a prediction dict.

    'start' comes from ``get_index``; 'end' is 'start' plus the length of the
    token as it appears in the text, i.e. without a leading '##' word-piece
    continuation marker.  The dict is modified in place and returned.

    Args:
        item: prediction dict with 'word' and 'index' keys.
        tokens: list of token strings for the text, without special tokens.

    Returns:
        The same ``item`` object, now with 'start' and 'end' set.
    """
    word = item['word']
    # Strip only the continuation prefix; a literal '#' token must keep its
    # length, otherwise it would get an empty (start == end) span.
    is_continuation = len(word) > 2 and word.startswith('##')
    surface_len = len(word) - 2 if is_continuation else len(word)

    start = get_index(item, tokens)
    item['start'] = start
    item['end'] = start + surface_len
    return item

In [22]:
import spacy
from spacy import displacy
import seaborn as sns
from pathlib import Path

def visualise(text, preds, path_to_output=None, lang="nl"):
    """Render token-level NER predictions over ``text`` with displaCy.

    Predictions are merged into entity spans: a label starting with 'B' opens
    a new span, any other label extends the current one (only 'B' and 'I'
    labels are expected; 'O' is not included).  A span whose character offsets
    do not yield a spaCy span (``Doc.char_span`` returns a falsy value) is
    skipped silently.

    Args:
        text: the original text the predictions refer to.
        preds: iterable of dicts with 'entity' (e.g. 'B-PER'), 'start' and
            'end' (character offsets into ``text``).
        path_to_output: if falsy, the rendering is displayed inline in the
            notebook; otherwise the rendered markup is written to this path
            (UTF-8).
        lang: language code passed to ``spacy.blank``; defaults to "nl",
            the value that was previously hard-coded.

    Returns:
        None.
    """
    ## Step 1: adding entities
    nlp = spacy.blank(lang)  # it should work with any language
    doc = nlp(text + ' ')  # a hack: a trailing space is appended before tokenising

    entities = []
    labels = []  # entity types in order of first appearance

    def add_span(start, end, label):
        # char_span gives a falsy result when no span can be built from the
        # offsets; such entities are dropped.
        span = doc.char_span(start, end, label)
        if span:
            entities.append(span)

    cur_type = ''
    cur_start, cur_end = 0, 0

    for pred in preds:
        ent = pred['entity']
        if ent.startswith('B'):
            ## Close the previous entity if there is one
            if cur_type != '':
                add_span(cur_start, cur_end, cur_type)

            ## Open the new entity
            cur_type = ent[2:]
            if cur_type not in labels:
                labels.append(cur_type)
            cur_start = pred['start']
        # Both 'B' and 'I' predictions move the end of the current entity.
        cur_end = pred['end']

    ## Adding the last one
    if cur_type != '':
        add_span(cur_start, cur_end, cur_type)

    doc.ents = entities

    ## Step 2: visualising
    colours = sns.color_palette("Set2", len(labels)).as_hex()
    options = {"ents": list(labels),
               "colors": dict(zip(labels, colours))}

    if not path_to_output:
        displacy.render(doc, style="ent", options=options, jupyter=True)
    else:
        markup = displacy.render(doc, style='ent', options=options, jupyter=False)
        # write_text opens, writes and closes the file; the previous
        # open(...).write(...) left the handle unclosed.
        Path(path_to_output).write_text(markup, encoding='utf-8')

In [21]:
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification


# Path to the local model checkpoint; defined once so the tokenizer and the
# model are always loaded from the same place.
MODEL_PATH = "/ivi/ilps/personal/vprovat/KB/models/BERTje-NER-v2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForTokenClassification.from_pretrained(MODEL_PATH)
nlp = pipeline("ner", model=model, tokenizer=tokenizer)

example = "Ik ben Bert de Jong, en ik woon in Zandvoort aan Zee."

ner_results = nlp(example)
# post_process edits each dict in place, so `res` holds the same objects as
# `ner_results`.
res = [post_process(item) for item in ner_results]

In [23]:
# Word-piece tokens of the example without special tokens.  tokenize() returns
# them directly, instead of decoding every id separately and slicing off the
# first and last entry on the assumption that they are [CLS] and [SEP].
tokens = tokenizer.tokenize(example)
# insert_index adds 'start'/'end' to each dict in place, so the items in `res`
# are updated as well.
res_for_visualisation = [insert_index(item, tokens) for item in res]

In [26]:
# Render the predicted entities inline in the notebook.  Pass
# path_to_output='plots/NER_example.svg' to write the rendering to a file instead.
visualise(example, res_for_visualisation)

In [25]:
# Inspect the predictions together with the 'start'/'end' offsets added by insert_index.
res_for_visualisation

[{'entity': 'B-PER',
  'score': 0.9783231,
  'index': 3,
  'word': 'Bert',
  'start': 7,
  'end': 11},
 {'entity': 'I-PER',
  'score': 0.91749096,
  'index': 4,
  'word': 'de',
  'start': 12,
  'end': 14},
 {'entity': 'I-PER',
  'score': 0.97511953,
  'index': 5,
  'word': 'Jong',
  'start': 15,
  'end': 19},
 {'entity': 'B-LOC',
  'score': 0.98678595,
  'index': 11,
  'word': 'Zandvoort',
  'start': 35,
  'end': 44},
 {'entity': 'I-LOC',
  'score': 0.6270297,
  'index': 12,
  'word': 'aan',
  'start': 45,
  'end': 48},
 {'entity': 'I-LOC',
  'score': 0.8423254,
  'index': 13,
  'word': 'Ze',
  'start': 49,
  'end': 51},
 {'entity': 'I-LOC',
  'score': 0.8233484,
  'index': 14,
  'word': '##e',
  'start': 51,
  'end': 52}]

In [269]:
# Show the sentence being analysed.
# NOTE(review): the saved output of this cell shows a different sentence than the
# one assigned to `example` in the model-loading cell — presumably a stale output
# from an earlier run; re-run the notebook top to bottom to confirm.
example

'Ik ben Jip de Kip, en ik woon in Zandvoort aan Zee'

In [270]:
# Sanity check: the character found in `example` at each computed offset should
# be the first character of the predicted token.  Mismatches are printed as
# "<offset> <char in text> != <first char of token>".
# NOTE(review): the saved output of this cell ("Current start: ..." lines) is not
# produced by this code — presumably left over from an earlier debugging
# version; re-run to refresh it.
for item in res_for_visualisation:
    word = item['word']
    # Drop only a leading '##' continuation marker; a literal '#' token stays
    # as it is (stripping every '#' would leave it empty).
    surface = word[2:] if len(word) > 2 and word.startswith('##') else word
    idx = get_index(item, tokens)  # computed once per item
    # Slices instead of indexing, so an offset past the end of the text (or an
    # empty token) is reported as a mismatch rather than raising IndexError.
    found, expected = example[idx:idx + 1], surface[:1]
    if found != expected:
        print(idx, found, '!=', expected)

Ik
Current start:  2
ben
Current start:  6
Ik
Current start:  2
ben
Current start:  6
Ji
Current start:  9
Ik
Current start:  2
ben
Current start:  6
Ji
Current start:  9
##p
Current start:  10
Ik
Current start:  2
ben
Current start:  6
Ji
Current start:  9
##p
Current start:  10
de
Current start:  13
Ik
Current start:  2
ben
Current start:  6
Ji
Current start:  9
##p
Current start:  10
de
Current start:  13
Kip
Current start:  17
,
Current start:  18
en
Current start:  21
ik
Current start:  24
woon
Current start:  29
in
Current start:  32
Ik
Current start:  2
ben
Current start:  6
Ji
Current start:  9
##p
Current start:  10
de
Current start:  13
Kip
Current start:  17
,
Current start:  18
en
Current start:  21
ik
Current start:  24
woon
Current start:  29
in
Current start:  32
Zandvoort
Current start:  42
Ik
Current start:  2
ben
Current start:  6
Ji
Current start:  9
##p
Current start:  10
de
Current start:  13
Kip
Current start:  17
,
Current start:  18
en
Current start:  21
ik
Curr

In [252]:
# Scratch check of str.isalpha() on a letter; get_index uses this test on the
# first character of a token to decide whether a space precedes it.
'I'.isalpha()

True

In [251]:
# Inspect the word-piece tokens that get_index walks over.
# NOTE(review): the saved output lists tokens such as 'Ji', '##p', 'Kip', which do
# not occur in the `example` sentence assigned in the model-loading cell —
# presumably a stale output from an earlier run; re-run to confirm.
tokens

['Ik',
 'ben',
 'Ji',
 '##p',
 'de',
 'Kip',
 ',',
 'en',
 'ik',
 'woon',
 'in',
 'Zandvoort',
 'aan',
 'Ze',
 '##e']