In [None]:
import spacy
from spacy.tokens import DocBin, Doc, Span
from tqdm import tqdm
import traceback
import pandas as pd
import pickle
import random
import numpy as np
from pathlib import Path
import json
spacy.require_gpu()


from spacy.scorer import Scorer
from spacy.tokens import Doc
from spacy.training.example import Example


## Evaluate the performance

In [None]:
def spacy_evaluate(ner_model, examples):
    """Score a pipeline's predictions against gold annotations.

    Parameters
    ----------
    ner_model : callable
        Invoked once per text; its return value is handed to
        ``Example.from_dict`` as the predicted document.
    examples : iterable of (text, annotations) pairs
        ``annotations`` is passed unchanged to ``Example.from_dict``
        (``prepare_examples`` in this notebook builds ``{"entities": [...]}``).

    Returns
    -------
    The result of ``Scorer.score`` over all examples. Callers in this
    notebook read its ``'ents_per_type'`` entry.
    """
    # Pair each prediction with its gold annotations so the scorer can
    # compare them.
    scored_pairs = [
        Example.from_dict(ner_model(text), annotations)
        for text, annotations in examples
    ]
    return Scorer().score(scored_pairs)

def prepare_examples(line):
    """Convert one page record into a ``(text, annotations)`` pair.

    Parameters
    ----------
    line : mapping
        Must provide ``'text'`` and ``'spans'``; every span must provide
        ``'start'``, ``'end'`` and ``'label'``.

    Returns
    -------
    tuple
        ``(text, {"entities": [(start, end, label), ...]})`` with the spans
        kept in their original order.
    """
    text = line['text']
    entity_offsets = []
    for span in line['spans']:
        entity_offsets.append((span['start'], span['end'], span['label']))
    return text, {"entities": entity_offsets}

In [None]:
# Experiment grid: every training-set size below was trained 10 times.

# Training-set sizes (number of samples) that were evaluated.
DATA_SAMPLES = [100, 200, 300, 400, 500]
# Indices of the repeated training runs: 0 through 9.
TRAIN_RUN = np.arange(10)

In [None]:
%%time

def load_gold_pages(holdout_path, vocab):
    """Read a serialized DocBin and return its gold entities as plain dicts.

    Parameters
    ----------
    holdout_path : str or path-like
        Location of the ``.spacy`` file to read.
    vocab :
        Vocab passed to ``DocBin.get_docs`` to rebuild the documents.

    Returns
    -------
    list of dict
        One ``{"text": ..., "spans": [{"start", "end", "label"}, ...]}``
        record per document, using character offsets.
    """
    gold_docs = DocBin().from_disk(holdout_path).get_docs(vocab)
    return [
        {
            "text": doc.text,
            "spans": [
                {"start": ent.start_char, "end": ent.end_char, "label": ent.label_}
                for ent in doc.ents
            ],
        }
        for doc in gold_docs
    ]


# The vocab of a blank English pipeline is used only to deserialize the gold
# docs; build it once instead of once per (sample, run) pair.
gold_vocab = spacy.blank("en").vocab

# The holdout file depends on the run index only, so each one is read from
# disk a single time and reused for every training-set size.
gold_pages_by_run = {}

# Collected as (per-entity-type scores, "<sample>-<run>") pairs.
to_plot = []
for sample in tqdm(DATA_SAMPLES, total = len(DATA_SAMPLES)):
    for run in tqdm(TRAIN_RUN, total = len(TRAIN_RUN)):
        model_dir = f"data_curve/{sample}_samples/runs_{run}/model-best"
        hold_out = f"data_curve/holdout_{run}.spacy"

        # evaluate only on the holdout set
        run_key = int(run)
        if run_key not in gold_pages_by_run:
            gold_pages_by_run[run_key] = load_gold_pages(hold_out, gold_vocab)
        data = gold_pages_by_run[run_key]

        # get predicted entities and score them against the gold spans
        nlp = spacy.load(model_dir)
        ner = spacy_evaluate(nlp, [prepare_examples(page) for page in data])['ents_per_type']

        to_plot.append((ner, f"{sample}-{run}"))

        # To compare against a pretrained model instead, load it in place of
        # the fine-tuned one and record it under sample size 0:
        #
        # nlp = spacy.load("en_core_web_trf")
        # ner = spacy_evaluate(nlp, [prepare_examples(page) for page in data])['ents_per_type']
        # to_plot.append((ner, f"{0}-{run}"))
        

In [None]:
# Flatten the per-run score dictionaries into long format: one row per
# (run, entity type, metric).
rows = [
    (per_type[label][score_key], score_name, label, run_label)
    for per_type, run_label in to_plot
    for label in per_type
    # Only these entity types are reported.
    if label in ('PERSON', 'LOC', 'IDNUM', 'EMAIL', 'ORG')
    for score_key, score_name in (("p", "precision"), ("r", "recall"), ("f", "f1"))
]
df = pd.DataFrame(rows, columns = ["value", "score_type", "entity", "category"])

# "category" has the form "<sample>-<run>"; split it once and expose both
# parts as integer columns.
label_parts = df['category'].str.split('-')
df['sample'] = label_parts.str[0].astype(int)
df['run'] = label_parts.str[-1].astype(int)


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from matplotlib.ticker import AutoMinorLocator

# Figure: F-1 score per entity type as a function of the number of
# fine-tuning samples, with error bars across the training runs.

# White background with tick marks on the bottom and left axes.
sns.set_style('white', rc={
    'xtick.bottom': True,
    'ytick.left': True,
})

matplotlib.rcParams.update({"axes.labelsize": 14,
"xtick.labelsize": 14,
"ytick.labelsize": 14,
"legend.fontsize": 14,
"font.size":14})
matplotlib.rc('font', family='Helvetica')
# Font type 42 embeds TrueType fonts in saved PDFs.
matplotlib.rc('pdf', fonttype=42)
# A real boolean rather than the string 'false'.
matplotlib.rc('text', usetex=False)
matplotlib.rcParams['axes.unicode_minus'] = False

# Same tick geometry for both axes and for major and minor ticks.
for axis_name in ('xtick', 'ytick'):
    for tick_kind in ('major', 'minor'):
        matplotlib.rcParams[f'{axis_name}.{tick_kind}.size'] = 2 * 2
        matplotlib.rcParams[f'{axis_name}.{tick_kind}.width'] = 0.5 * 2

# Set2 is the palette actually used for the hue levels.
sns.set_palette("Set2")
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5), dpi=300)

# Thin black bottom/left spines; the top/right spines are removed by
# sns.despine below.
for side in ('bottom', 'left'):
    ax.spines[side].set_color('black')
    ax.spines[side].set_linewidth(0.5)
for axis, axis_obj in (('x', ax.xaxis), ('y', ax.yaxis)):
    axis_obj.label.set_color('black')
    ax.tick_params(axis=axis, colors='black')

# Draw on the axes created above instead of relying on the implicit
# "current axes".
# NOTE(review): `scale` and `errwidth` are deprecated in newer seaborn
# releases; they are kept so the cell still runs on the version in use.
sns.pointplot(data=df.loc[df.score_type == 'f1'], x="sample", y="value", hue='entity',
              errorbar='ci', scale=1, capsize=.15, errwidth=2, ax=ax)

sns.despine(ax=ax)
ax.set_ylim(0, 1)
ax.set_xlabel("Samples finetuned")
ax.set_ylabel("F-1 score")
ax.legend(frameon = False)

# plt.savefig("../figures/data_curve_pii_detection.pdf", bbox_inches = "tight", dpi = 300)
plt.show()


In [None]:
# Mean precision per entity type for the models trained on 500 samples.
# Filter on the integer `sample` column: a substring match on the
# "<sample>-<run>" label would also match sizes such as 1500.
df.loc[(df["sample"] == 500) & (df.score_type == 'precision')].groupby('entity').agg({"value": 'mean'})

In [None]:
# Standard deviation of precision per entity type for the models trained on
# 500 samples. Filter on the integer `sample` column: a substring match on
# the "<sample>-<run>" label would also match sizes such as 1500.
df.loc[(df["sample"] == 500) & (df.score_type == 'precision')].groupby('entity').agg({"value": 'std'})

In [None]:
# Mean recall per entity type for the models trained on 500 samples.
# Filter on the integer `sample` column: a substring match on the
# "<sample>-<run>" label would also match sizes such as 1500.
df.loc[(df["sample"] == 500) & (df.score_type == 'recall')].groupby('entity').agg({"value": 'mean'})

In [None]:
# Standard deviation of recall per entity type for the models trained on
# 500 samples. Filter on the integer `sample` column: a substring match on
# the "<sample>-<run>" label would also match sizes such as 1500.
df.loc[(df["sample"] == 500) & (df.score_type == 'recall')].groupby('entity').agg({"value": 'std'})

In [None]:
# Mean F-1 per entity type for the models trained on 500 samples.
# Filter on the integer `sample` column: a substring match on the
# "<sample>-<run>" label would also match sizes such as 1500.
df.loc[(df["sample"] == 500) & (df.score_type == 'f1')].groupby('entity').agg({"value": 'mean'})

In [None]:
# Standard deviation of F-1 per entity type for the models trained on
# 500 samples. Filter on the integer `sample` column: a substring match on
# the "<sample>-<run>" label would also match sizes such as 1500.
df.loc[(df["sample"] == 500) & (df.score_type == 'f1')].groupby('entity').agg({"value": 'std'})

# The mistakes are close to the gold-standard annotations

In [None]:
def entity_texts(doc, label):
    """Return the text of every entity in ``doc`` labelled ``label``, in document order."""
    return [ent.text for ent in doc.ents if ent.label_ == label]


data = []

# Inspect one model (500 samples, run 1) on its matching holdout set.
nlp = spacy.load("data_curve/500_samples/runs_1/model-best/")
doc_bin = DocBin().from_disk("data_curve/holdout_1.spacy")

total = 0  # running count of gold LOC entities
for doc in doc_bin.get_docs(nlp.vocab):
    pred = nlp(doc.text_with_ws)

    gold = entity_texts(doc, "LOC")
    predict = entity_texts(pred, "LOC")
    predict_org = entity_texts(pred, "ORG")

    total += len(gold)

    # Documents whose predicted LOC texts match the gold ones exactly are
    # not reported.
    if gold == predict:
        continue

    # Show the gold LOC texts, the predicted LOC texts, and the predicted
    # ORG texts.
    print("real", gold)
    print("predicted", predict)
    print("org", predict_org)
    print("=========")