## Inspect EstRoBERTa's NER errors on the test set

In [1]:
import os, os.path, re

from estnltk import Text
from estnltk.converters import json_to_text

# Directory whose files are listed and evaluated when the error inspection is run
test_data_dir  = '../data/test'

# File that receives the categorized error report (the same lines are also printed)
output_fname = 'estroberta_ner_errors_test_corpus.txt'

# Named entity labels.
# NOTE(review): not referenced anywhere in the visible part of this notebook -- 
# confirm whether it is still needed.
full_tagset = ['PER', 'LOC_ORG', 'LOC', 'ORG', 'MISC']

# Fail early if the test data directory is missing
assert os.path.exists(test_data_dir)

In [2]:
# Location of the best BERT model: the directory that holds the models and 
# the path of the specific model directory inside it.
# NOTE(review): the directory name looks like it encodes training settings 
# (batch size, learning rate, epochs) -- confirm against the training setup.
bert_models_dir = '../bert_models'
bertner_model_location = os.path.join( bert_models_dir, 'model_est-roberta_10_bs16_lr5e-05_ep8' )
# Fail early if the model directory is missing
assert os.path.exists(bertner_model_location)

In [4]:
# Allow relative imports from the root dir: the parent directory is put first 
# on the module search path, so the project modules imported below resolve 
# from there. This must run before those imports.
import sys
sys.path.insert(0, '..')

# Load the best model.
# The same directory is passed as the first two positional arguments -- 
# presumably the model location and the tokenizer location; confirm against 
# BertNERTagger's signature. Annotations are written to the layer 'ner', 
# which is the layer the difference finder later compares against 'gold_ner'.
from bert_ner_tagger import BertNERTagger
bert_ner_tagger_phrases = BertNERTagger(bertner_model_location, bertner_model_location, output_layer='ner',
                                        token_level=False, do_lower_case=False, use_fast=False)

# Load preprocessing utils (applied to every document before tagging)
from data_preprocessing import TokenizationPreprocessorFixed

In [5]:
from estnltk.converters import json_to_text
from ner_diff_utils import NerDiffFinder
from tqdm import tqdm

def _text_snippet( text_obj, start, end ):
    '''Takes a snippet out of the text, assuring that text boundaries are not exceeded.
       Also, makes newline character explicit in text by replacing it with '\\n'.
    '''
    start = 0 if start < 0 else start
    start = len(text_obj.text) if start > len(text_obj.text) else start
    end   = len(text_obj.text) if end > len(text_obj.text)   else end
    end   = 0 if end < 0 else end
    snippet = text_obj.text[start:end]
    snippet = snippet.replace('\n', '\\n')
    return snippet

def _format_diff_side( prefix, spans, text_obj, conflict_loc, label_attr, N ):
    '''Renders one side (gold or automatic) of an annotation difference as a single string: 
       the prefix, N characters of left context, every span as '{text} /label' (with the 
       plain text that lies between consecutive spans), and N characters of right context.
    '''
    if spans:
        # Anchor the context to the outermost spans of this side
        left_edge   = spans[0].start
        right_start = spans[-1].end
        right_end   = right_start + N
    else:
        # No annotations on this side: the left context ends where the conflict 
        # begins, and the conflict region itself is shown as plain text in front 
        # of the right context
        left_edge   = conflict_loc[0]
        right_start = conflict_loc[0]
        right_end   = conflict_loc[1] + N
    pieces = [ prefix, '...'+_text_snippet( text_obj, left_edge-N, left_edge ) ]
    prev_end = None
    for span in spans:
        # Fill the gap between two consecutive spans with the plain text
        if prev_end is not None and prev_end != span.start:
            pieces.append( _text_snippet( text_obj, prev_end, span.start ) )
        label = span.annotations[0][label_attr]
        pieces.append( '{'+_text_snippet( text_obj, span.start, span.end )+'} /'+label )
        prev_end = span.end
    pieces.append( _text_snippet( text_obj, right_start, right_end )+'...' )
    return ''.join(pieces)

def _format_grouped_diff( group, layer_a='gold_ner_flat', layer_b='ner_flat', label_attr='__label', N=40 ):
    '''Adds textual context (N chars left and right) to the given group (of annotation differences) 
       and returns formatted difference lines (list of strings).
       
       Parameters:
         group      -- dict that holds the spans of both layers under the keys layer_a and 
                       layer_b, and the (start, end) location of the conflict under '__loc'. 
                       At least one of the two layers must contain a span, because the 
                       text object is taken from the first span available;
         layer_a    -- key of the gold standard spans (must contain the substring 'gold');
         layer_b    -- key of the automatically tagged spans;
         label_attr -- attribute of the first annotation of a span that holds the label;
         N          -- number of context characters shown on both sides.
       
       Returns a list of two strings: the ' gold:  ' line followed by the ' auto:  ' line.
    '''
    assert 'gold' in layer_a
    conflict_loc = group['__loc']
    gold_spans = list( group[layer_a] )
    auto_spans = list( group[layer_b] )
    # The text is reached through whichever span exists (gold side preferred)
    text_obj = gold_spans[0].text_object if gold_spans else auto_spans[0].text_object
    return [ _format_diff_side( ' gold:  ', gold_spans, text_obj, conflict_loc, label_attr, N ),
             _format_diff_side( ' auto:  ', auto_spans, text_obj, conflict_loc, label_attr, N ) ]

def write_out( line, fp ):
    '''Prints the line to stdout and writes it, followed by a newline, to fp. 
       A value that is not a string is converted with str() first.
    '''
    rendered = line if isinstance(line, str) else str(line)
    print(rendered)
    fp.write(rendered + '\n')

def _classify_grouped_diff( g_diff ):
    '''Determines the error category of a group of annotation differences. 
       Returns a tuple (status, entity), where status is one of 'missing', 'extra', 
       'wrong_label', 'wrong_boundaries + wrong_label', 'wrong_boundaries', and entity 
       is the text of the first gold span (or of the first automatic span if there 
       are no gold spans in the group).
    '''
    gold_spans = g_diff['gold_ner_flat']
    auto_spans = g_diff['ner_flat']
    if len(auto_spans) == 0 and len(gold_spans) > 0:
        return 'missing', gold_spans[0].text
    if len(auto_spans) > 0 and len(gold_spans) == 0:
        return 'extra', auto_spans[0].text
    entity = gold_spans[0].text
    if len(auto_spans) == 1 and len(gold_spans) == 1:
        gold_span, auto_span = gold_spans[0], auto_spans[0]
        gold_label = gold_span.annotations[0]['__label']
        auto_label = auto_span.annotations[0]['__label']
        if gold_span.start == auto_span.start and gold_span.end == auto_span.end:
            # Identical boundaries are only reported as a difference if labels differ
            assert gold_label != auto_label
            return 'wrong_label', entity
        if gold_label != auto_label:
            return 'wrong_boundaries + wrong_label', entity
        return 'wrong_boundaries', entity
    # Several spans on at least one side: labels are not compared
    return 'wrong_boundaries', entity

def evaluate_on_auto_morph_preprocess( tagger, input_dir, input_files, output_file ):
    '''Uses (BertNER)Tagger to annotate input_files in the input_dir and compares automatic 
       NE annotations to gold standard NE annotations available in files. 
       Groups annotation differences into categories (such as 'missing', 'extra', 'wrong_label'), 
       adds context to each difference, and finally outputs categorized differences to standard 
       output stream and to output_file.
       
       Parameters:
         tagger      -- tagger that is applied to every document via tagger.tag(text_obj);
         input_dir   -- directory that is listed for input files;
         input_files -- names of the files to be processed; only names that end with '.json' 
                        and are present in input_dir are used;
         output_file -- path of the report file (overwritten, utf-8).
       
       Categories are reported from the most frequent to the least frequent one; inside 
       a category, differences are sorted by the entity string.
    '''
    ner_diff_finder = NerDiffFinder('gold_ner', 
                                    'ner', 
                                     old_layer_attr='nertag', 
                                     new_layer_attr='nertag')
    preprocessor = TokenizationPreprocessorFixed()
    # Set membership avoids scanning the whole list of names for every directory entry
    selected_files = set(input_files)
    json_suffix = '.json'
    doc_count = 0
    grouped_errors = dict()
    # os.listdir() returns entries in arbitrary order; sorting keeps the order of 
    # differences that share the same entity string stable across runs
    for fname in tqdm( sorted(os.listdir(input_dir)) ):
        if not (fname.endswith(json_suffix) and fname in selected_files):
            continue
        text_obj = json_to_text(file=os.path.join(input_dir, fname))
        preprocessor.preprocess( text_obj )
        # Strip only the file extension (not every occurrence of '.json' in the name)
        text_obj.meta['fname'] = fname[:-len(json_suffix)]
        tagger.tag(text_obj)
        _, _, grouped_diffs, _ = ner_diff_finder.find_difference( text_obj, fname )
        for g_diff in grouped_diffs:
            status, entity = _classify_grouped_diff( g_diff )
            g_diff['__status'] = status
            g_diff['__entity'] = entity
            g_diff['__output'] = _format_grouped_diff( g_diff )
            g_diff['__file'] = text_obj.meta['fname']
            grouped_errors.setdefault( status, [] ).append( g_diff )
        doc_count += 1
    with open(output_file, 'w', encoding='utf-8') as out_f:
        write_out( '', out_f )
        write_out( grouped_errors.keys(), out_f )
        write_out( '', out_f )
        for err_type in sorted(grouped_errors.keys(), key=lambda x : len(grouped_errors[x]), reverse=True):
            write_out( '='*70, out_f )
            write_out( err_type+' ('+str(len(grouped_errors[err_type]))+')', out_f )
            write_out( '='*70, out_f )
            for g_diff in sorted(grouped_errors[err_type], key=lambda x:x['__entity']):
                diff_lines = g_diff['__output']
                # The file name is attached to the first line of every difference
                write_out( diff_lines[0]+' | '+g_diff['__file'], out_f )
                for diff_line in diff_lines[1:]:
                    write_out( diff_line, out_f )
                write_out( '', out_f )
            write_out( '', out_f )

        write_out( f' Evaluated on {doc_count} documents. ', out_f )

In [6]:
# Run the error inspection over every file of the test data directory (only '.json' 
# files are processed); the report is printed below and also written to output_fname.
evaluate_on_auto_morph_preprocess( bert_ner_tagger_phrases, test_data_dir, os.listdir(test_data_dir), output_fname )

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 252/252 [04:37<00:00,  1.10s/it]


dict_keys(['missing', 'extra', 'wrong_label', 'wrong_boundaries + wrong_label', 'wrong_boundaries'])

wrong_boundaries (115)
 gold:  ...s, et tema Johan Kripsoniga "Kribsi" ja {"Mõtskribsi"} /LOC_ORG talud pooleks ostnud on; aga maad on we... | Tartu_V6nnu_Ahja_id17889_1885a
 auto:  ...s, et tema Johan Kripsoniga "Kribsi" ja {"Mõtskribsi" talud} /LOC_ORG pooleks ostnud on; aga maad on weel poo...

 gold:  ...k Kiwita kuha jägu on, sest tema on oma {(Laiba) kuha} /LOC_ORG ära ostnud nende rajadega, mis maamõõtj... | Harju_Jyri_Rae_id6623_1890a
 auto:  ... Kiwita kuha jägu on, sest tema on oma ({Laiba} /LOC_ORG) kuha ära ostnud nende rajadega, mis ma...

 gold:  ...rjutaja Wreimanniga õiges om saanud, ja {Adam} /PER Suurmann om sis selle massu tarwis tood... | V6ru_R2pina_Kahkva_id13937_1889a
 auto:  ...rjutaja Wreimanniga õiges om saanud, ja {Adam Suurmann} /PER om sis selle massu tarwis toodud wilja ...

 gold:  ...utty Jaani wargusse, mees sai Raetseppa {Adamist} /PER kätte saadut  j




---