In [1]:
from utils import *
import yaml

import logging
logging.getLogger().setLevel(logging.CRITICAL)

import warnings
warnings.filterwarnings('ignore')

args_file = "./args.yaml"

with open(args_file) as f:
    args_dict = yaml.load(f)

extra_args_file = ""
    
if extra_args_file:
    with open(extra_args_file) as f:
        extra_args_dict = yaml.load(f)
        for k,v in extra_args_dict.items():
            args_dict[k] = v

args = Dict2Obj(args_dict)

## Load model

In [2]:
# checkpoint_dir = "./output/11May2020-12:22:15/checkpoint-915"
checkpoint_dir = "./output/15May2020-14:47:51/checkpoint-198"
args['model_name_or_path'] = checkpoint_dir
args['config_name'] = checkpoint_dir
args['tokenizer_name'] = checkpoint_dir

In [3]:
ner_system = BertNerSystem(args)  # Load model
trainer = get_trainer(ner_system, args)  # get the trainer

In [4]:
ner_system.to('cuda:0')

BertNerSystem(
  (model): BertForTokenClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 1024, padding_idx=0)
        (position_embeddings): Embedding(512, 1024)
        (token_type_embeddings): Embedding(2, 1024)
        (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=1024, out_features=1024, bias=True)
                (key): Linear(in_features=1024, out_features=1024, bias=True)
                (value): Linear(in_features=1024, out_features=1024, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=1024, out_features=1024,

In [5]:
trainer.test(ner_system)

Loading test data: 100%|██████████| 543/543 [00:04<00:00, 132.15it/s]
Loading test data: 100%|██████████| 543/543 [00:04<00:00, 135.43it/s]

Testing: 0it [00:00, ?it/s]




Testing: 100%|██████████| 231/231 [00:19<00:00, 11.54it/s]

INFO:utils:***** Test results *****
INFO:utils:exact_nonpositional = 0.757

INFO:utils:exact_subtoken = 0.443

INFO:utils:exact_token = 0.454

INFO:utils:partial_nonpositional = 0.795

INFO:utils:partial_subtoken = 0.485

INFO:utils:partial_token = 0.488



--------------------------------------------------------------------------------
TEST RESULTS
{'exact_nonpositional': 0.7566607460035524,
 'exact_subtoken': 0.44268292682926824,
 'exact_token': 0.45383222691611347,
 'partial_nonpositional': 0.7952167414050823,
 'partial_subtoken': 0.48536036036036034,
 'partial_token': 0.48783248443689875}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 231/231 [00:20<00:00, 11.41it/s]


## Get test data

In [6]:
val_dataset = BertNerDataset(
    ner_system.tokenizer, data_path=args.val_data_path, type_path="val", \
    doc_size=args.doc_max_seq_length, batch_size=args.eval_batch_size
)
dataloader = DataLoader(val_dataset, batch_size=args.eval_batch_size, shuffle=False, drop_last=True)

Loading val data: 100%|██████████| 543/543 [00:04<00:00, 133.81it/s]


In [7]:
batches = list(iter(dataloader))

In [7]:
ner_system.test_step

<bound method BertNerSystem.test_step of BertNerSystem(
  (model): BertForTokenClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 1024, padding_idx=0)
        (position_embeddings): Embedding(512, 1024)
        (token_type_embeddings): Embedding(2, 1024)
        (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=1024, out_features=1024, bias=True)
                (key): Linear(in_features=1024, out_features=1024, bias=True)
                (value): Linear(in_features=1024, out_features=1024, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Li

## Make single prediction

In [28]:
batch = batches[15]

outputs, = ner_system.forward(input_ids=batch['tokens'].to('cuda:0'), attention_mask=batch["attention_mask"].to('cuda:0'))
iob_predictions = torch.argmax(outputs, dim=2)

iob_predictions = iob_predictions.squeeze(dim=0).tolist()[0]
tokens = batch['tokens'].squeeze(dim=0).tolist()[0]

predictions = set()
words = ""
for i,p in enumerate(iob_predictions):
    if tokens[i] == 0: continue
    if p == 0: continue
    if p == 1: 
        predictions.add(str(words[1:]))
        words = ""
    
    sub_token = vocab[tokens[i]]
    if sub_token[:2] == '##':
        sub_token = sub_token[2:]
    else:
        sub_token = ' ' + sub_token
    words += sub_token
else:    
    if len(words):
        predictions.add(str(words[1:]))

if "" in predictions:
    predictions.remove("")

predictions

{'an', 'chronic myelomonocytic leukemia', 'myelodysplastic syndromes'}

## Get all predictions

In [10]:
predictions = []
for batch in tqdm(batches):
    outputs, = ner_system.forward(input_ids=batch['tokens'].to('cuda:0'), attention_mask=batch["attention_mask"].to('cuda:0'))
    batch_predictions = torch.argmax(outputs, dim=2)
    predictions.append(batch_predictions.squeeze(dim=0).tolist())

HBox(children=(FloatProgress(value=0.0, max=462.0), HTML(value='')))




## Test metrics

Standard metrics:
* Exact
* Partial (some overlap)

Modifications to handle sub-tokenisation:
* Token
* Sub-token

Modificiations to handle multiple mentions:
* Positional
* Non-positional (retrieve tokens from document and compare)

In [1]:
import itertools as it

def score_overlap(loc1,loc2):
    range1 = set(range(loc1[0],loc1[1]+1))
    range2 = set(range(loc2[0],loc2[1]+1))
    
    if range1 & range2:
        if range1 == range2:
            return "exact"
        else:
            return "partial"
    else:
        return False
    
def get_locs(labels):
    locs = []
    new_loc = True
    loc_range = []
    for i,l in enumerate(labels):
        if l == 0:
            new_loc = True
            continue
        if l > 0:
            if new_loc:
                if len(loc_range):
                    locs.append((loc_range[0],loc_range[-1]))
                loc_range = []
                new_loc = False
            loc_range.append(i)
    
    return locs

In [12]:
score_map = {'exact': 2, 'partial': 1, False: 0}

In [13]:
batch_i = 5
batch = batches[batch_i]
prediction = predictions[batch_i]
batch_labels = batch['labels'].squeeze(dim=0).tolist()

### Sub-tokens

In [2]:
def test_predictions_subtoken(predictions, targets, batch, score_map={'exact': 2, 'partial': 1, False: 0}):
    prediction_locs = get_locs(predictions)
    target_locs = get_locs(targets)
    
    prediction_matches = {p_loc: False for p_loc in prediction_locs}
    target_matches = {t_loc: False for t_loc in target_locs}

    for p_loc, t_loc in it.product(prediction_locs, target_locs):
        score = score_overlap(p_loc,t_loc)
        if score_map[score] > score_map[prediction_matches[p_loc]]:
            prediction_matches[p_loc] = score
        if score_map[score] > score_map[target_matches[t_loc]]:
            target_matches[t_loc] = score
            
    exacts = len([k for k,v in target_matches.items() if v == 'exact'])
    partials = len([k for k,v in target_matches.items() if v == 'partial'])
    misses = len([k for k,v in target_matches.items() if v == False])
    spurious = len([k for k,v in prediction_matches.items() if v == False])
    
    return exacts, partials, misses, spurious

In [3]:
def get_metrics(tp,fn,fp):
    recall = tp / (tp + fn)
    precision = tp / (tp + fp)
    f1 = 2 * precision * recall / (precision + recall)
    
    return {'precision': precision, 'recall': recall, 'f1': f1}

In [39]:
exacts, partials, misses, spurious = test_predictions(prediction, batch_labels)
exact = get_metrics(exacts,misses,spurious)  # exact
partial = get_metrics(exacts+partials,misses,spurious)  # partial

NameError: name 'prediction' is not defined

In [17]:
exact

{'precision': 0.7142857142857143,
 'recall': 0.45454545454545453,
 'f1': 0.5555555555555556}

In [18]:
partial

{'precision': 0.8181818181818182, 'recall': 0.6, 'f1': 0.6923076923076923}

### Tokens

Barely makes any difference

In [4]:
def map_subtokens(subtokens):
    subtoken_map = []
    token_i = 0
    for t in subtokens:
        if not t[:2] == "##":
            token_i += 1
        subtoken_map.append(token_i)

    return subtoken_map

In [20]:
subtokens = [vocab[t] for t in batch['tokens'].squeeze(dim=0).tolist()]

In [21]:
subtoken_map = map_subtokens(subtokens)

In [5]:
def subtoken2token_labels(labels, subtoken_map):
    token_labels = [0 for i in range(max(subtoken_map)+1)]
    for i,t in enumerate(labels):
        token_i = subtoken_map[i]
        if t > token_labels[token_i]:
            token_labels[token_i] = t
    return token_labels

In [23]:
token_targets = subtoken2token_labels(batch_labels, subtoken_map)
token_predictions = subtoken2token_labels(prediction, subtoken_map)

In [24]:
exacts, partials, misses, spurious = test_predictions(token_targets, token_predictions)
exact = get_metrics(exacts,misses,spurious)  # exact
partial = get_metrics(exacts+partials,misses,spurious)  # partial

In [25]:
exact

{'precision': 0.45454545454545453,
 'recall': 0.7142857142857143,
 'f1': 0.5555555555555556}

In [26]:
partial

{'precision': 0.6, 'recall': 0.8181818181818182, 'f1': 0.6923076923076923}

In [6]:
def test_predictions_token(predictions, targets, batch, score_map={'exact': 2, 'partial': 1, False: 0}):
    subtokens = [vocab[t] for t in batch['tokens'].squeeze(dim=0).tolist()]
    subtoken_map = map_subtokens(subtokens)
    targets = subtoken2token_labels(targets, subtoken_map)
    predictions = subtoken2token_labels(predictions, subtoken_map)
    
    return test_predictions_subtoken(predictions, targets, batch, score_map=score_map)

### All predictions

In [7]:
def test_all(predictions, batches, f, f_kwargs={}, token=False):
    exacts = 0
    partials = 0
    misses = 0
    spurious = 0

    for i in range(len(predictions)):
        batch = batches[i]
        prediction = predictions[i]
        target = batch['labels'].squeeze(dim=0).tolist()
        
        batch_exacts, batch_partials, batch_misses, batch_spurious = f(prediction, target, batch, **f_kwargs)
        
        exacts += batch_exacts
        partials += batch_partials
        misses += batch_misses
        spurious += batch_spurious

    exact = get_metrics(exacts,misses,spurious)  # exact
    partial = get_metrics(exacts+partials,misses,spurious)  # partial
    
    return exact, partial

In [60]:
exact, partial = test_all(predictions, token=False, subtoken_map=None)

In [61]:
exact

{'precision': 0.14462123011161243,
 'recall': 0.697594501718213,
 'f1': 0.2395751376868607}

In [62]:
partial

{'precision': 0.18229284903518728,
 'recall': 0.7525773195876289,
 'f1': 0.29349415204678364}

In [63]:
exact, partial = test_all(predictions, token=True, subtoken_map=subtoken_map)

In [64]:
exact

{'precision': 0.14589886281151707,
 'recall': 0.7011627906976744,
 'f1': 0.2415381534147807}

In [65]:
partial

{'precision': 0.18324849606663582,
 'recall': 0.7550047664442326,
 'f1': 0.2949171476447589}

### Non-positional

This indicates a much higher precision and F1.

In [8]:
def sequence_overlap(s1,s2):
    if s1[0] in s2:
        for begin_i in [i for i,t in enumerate(s2) if t == s1[0]]:
            for s1_i,s2_i in enumerate(range(begin_i,len(s2))):
                if s1_i >= len(s1):
                    continue
                if not s1[s1_i] == s2[s2_i]:
                    break
            else:
                return True
            
    return False

def sequence_overlap_twoway(s1,s2):
    if sequence_overlap(s1,s2): 
        return set(s1) & set(s2)
    if sequence_overlap(s2,s1): 
        return set(s1) & set(s2)
    return False

In [9]:
def get_labeled_tokens(labels, tokens):
    labeled_tokens = set()
    token_range = []
    for i,p in enumerate(labels):
        if tokens[i] == 0: continue
        if p == 0: continue
        if p == 1: 
            labeled_tokens.add(tuple(token_range[1:]))
            token_range = []
        token_range.append(tokens[i])
    else:    
        if len(token_range):
            labeled_tokens.add(tuple(token_range[1:]))
            
    if () in labeled_tokens:
        labeled_tokens.remove(())
    
    return labeled_tokens

In [10]:
def test_predictions_nonpositional(predictions, targets, batch, threshold=0.5):
    batch_tokens = batch['tokens'].squeeze(dim=0).tolist()
    prediction_tokens = get_labeled_tokens(predictions, batch_tokens)
    target_tokens = get_labeled_tokens(targets, batch_tokens)
    
    prediction_matches = {p_tokens: False for p_tokens in prediction_tokens}
    target_matches = {t_tokens: False for t_tokens in target_tokens}

    for p_tokens,t_tokens in it.product(prediction_tokens,target_tokens):
        overlap = sequence_overlap_twoway(p_tokens,t_tokens)
        if overlap:
            if ((len(overlap) / len(t_tokens)) > threshold) and ((len(overlap) / len(p_tokens)) > threshold):
                if t_tokens == p_tokens:
                    prediction_matches[p_tokens] = "exact"
                    target_matches[t_tokens] = "exact"
                else:
                    prediction_matches[p_tokens] = "partial"
                    target_matches[t_tokens] = "partial"

    exacts = len([k for k,v in target_matches.items() if v == 'exact'])
    partials = len([k for k,v in target_matches.items() if v == 'partial'])
    misses = len([k for k,v in target_matches.items() if v == False])
    spurious = len([k for k,v in prediction_matches.items() if v == False])
    
    return exacts, partials, misses, spurious

In [99]:
exact, partial = test_all(predictions, f=test_predictions_nonpositional, f_kwargs={'threshold': 0.33})

In [100]:
exact

{'precision': 0.6912568306010929,
 'recall': 0.7386861313868613,
 'f1': 0.7141848976711361}

In [101]:
partial

{'precision': 0.7460674157303371,
 'recall': 0.7876631079478055,
 'f1': 0.7663012117714946}

## Test different models and checkpoints

Bundle the functions together for convenience

In [17]:
from utils import *
import yaml

import logging
logging.getLogger().setLevel(logging.CRITICAL)

import warnings
warnings.filterwarnings('ignore')

In [18]:
def load_model(checkpoint_dir, args_fp="./args.yaml"):
    with open(args_fp) as f:
        args_dict = yaml.load(f)

    args = Dict2Obj(args_dict)
    
    args['model_name_or_path'] = checkpoint_dir
    args['config_name'] = checkpoint_dir
    args['tokenizer_name'] = checkpoint_dir
    
    ner_system = BertNerSystem(args)  # Load model
    trainer = get_trainer(ner_system, args)  # get the trainer
    vocab = {v:k for k,v in ner_system.tokenizer.get_vocab().items()}
    ner_system.model.eval()  # stops dropout
    ner_system.to('cuda:0')
    
    return ner_system, trainer, vocab

In [19]:
def load_test_data(ner_system, fp, doc_size=512, batch_size=1):
    val_dataset = BertNerDataset(
        ner_system.tokenizer, data_path=fp, type_path="val", \
        doc_size=doc_size, batch_size=batch_size
    )
    dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
    return list(iter(dataloader))

In [20]:
def generate_predictions(ner_system, batches):
    predictions = []
    for batch in tqdm(batches, desc="Generating predictions"):
        outputs, = ner_system.forward(input_ids=batch['tokens'].to('cuda:0'), attention_mask=batch["attention_mask"].to('cuda:0'))
        batch_predictions = torch.argmax(outputs, dim=2)
        predictions.append(batch_predictions.squeeze(dim=0).tolist())
    return predictions

In [21]:
def all_test_metrics(predictions, batches):
    exact_subtoken, partial_subtoken = test_all(predictions, batches, f=test_predictions_subtoken, f_kwargs={'score_map': {'exact': 2, 'partial': 1, False: 0}})
    exact_token, partial_token = test_all(predictions, batches, f=test_predictions_token, f_kwargs={'score_map': {'exact': 2, 'partial': 1, False: 0}})
    exact_nonpositional, partial_nonpositional = test_all(predictions, batches, f=test_predictions_nonpositional, f_kwargs={'threshold': 0.33})
    return {
        'exact_subtoken': exact_subtoken, 
        'partial_subtoken': partial_subtoken, 
        'exact_token': exact_token, 
        'partial_token': partial_token, 
        'exact_nonpositional': exact_nonpositional, 
        'partial_nonpositional': partial_nonpositional
    }

In [22]:
import pandas as pd

def metrics_table(metrics, metrics_types=['precision', 'recall', 'f1'], metrics_strategies=['exact_subtoken', 'partial_subtoken', 'exact_token', 'partial_token', 'exact_nonpositional', 'partial_nonpositional']):
    table = []
    for metrics_strategy in metrics_strategies:
        metrics_data = metrics[metrics_strategy]
        table_row = []
        for metrics_type in metrics_types:
            table_row.append(metrics_data[metrics_type])
        table.append(table_row)

    return pd.DataFrame(table, columns=metrics_types, index=metrics_strategies)

In [19]:
checkpoint_dir = "./output/11May2020-12:22:15/checkpoint-915"
ner_system, trainer, vocab = load_model(checkpoint_dir, args_fp="./args.yaml")

batches = load_test_data(ner_system, "./data/dev-ner.json")

predictions = generate_predictions(ner_system, batches)

metrics = all_test_metrics(predictions, batches)

df = metrics_table(metrics)
df

Loading val data: 100%|██████████| 543/543 [00:04<00:00, 118.32it/s]
Generating predictions: 100%|██████████| 462/462 [00:45<00:00, 10.08it/s]


Unnamed: 0,precision,recall,f1
exact_subtoken,0.144621,0.697595,0.239575
partial_subtoken,0.182293,0.752577,0.293494
exact_token,0.156652,0.71413,0.256942
partial_token,0.184835,0.753052,0.296817
exact_nonpositional,0.655172,0.71627,0.68436
partial_nonpositional,0.723837,0.776911,0.749436


In [73]:
checkpoint_dir = "./output/11May2020-12:22:15/checkpoint-1830"
ner_system, trainer, vocab = load_model(checkpoint_dir, args_fp="./args.yaml")

batches = load_test_data(ner_system, "./data/dev-ner.json")

predictions = generate_predictions(ner_system, batches)

metrics = all_test_metrics(predictions, batches)

df = metrics_table(metrics)
df

Loading val data: 100%|██████████| 543/543 [00:04<00:00, 116.34it/s]
Generating predictions: 100%|██████████| 462/462 [00:45<00:00, 10.10it/s]


Unnamed: 0,precision,recall,f1
exact_subtoken,0.150354,0.760358,0.251063
partial_subtoken,0.181876,0.799438,0.296335
exact_token,0.159523,0.771368,0.264372
partial_token,0.182814,0.799061,0.297552
exact_nonpositional,0.629956,0.75,0.684757
partial_nonpositional,0.664,0.776911,0.716032


In [74]:
checkpoint_dir = "./output/11May2020-12:22:15/checkpoint-2745"
ner_system, trainer, vocab = load_model(checkpoint_dir, args_fp="./args.yaml")

batches = load_test_data(ner_system, "./data/dev-ner.json")

predictions = generate_predictions(ner_system, batches)

metrics = all_test_metrics(predictions, batches)

df = metrics_table(metrics)
df

Loading val data: 100%|██████████| 543/543 [00:04<00:00, 117.30it/s]
Generating predictions: 100%|██████████| 462/462 [00:45<00:00, 10.05it/s]


Unnamed: 0,precision,recall,f1
exact_subtoken,0.150888,0.743274,0.250851
partial_subtoken,0.18341,0.78538,0.297374
exact_token,0.160163,0.75508,0.264271
partial_token,0.184222,0.784977,0.298412
exact_nonpositional,0.63253,0.740741,0.682372
partial_nonpositional,0.669377,0.770671,0.716461


## Models with extra LM pretraining

In [23]:
checkpoint_dir = "./output/15May2020-14:47:51/checkpoint-234"
ner_system, trainer, vocab = load_model(checkpoint_dir, args_fp="./args.yaml")

batches = load_test_data(ner_system, "./data/dev-ner.json")

predictions = generate_predictions(ner_system, batches)

metrics = all_test_metrics(predictions, batches)

df = metrics_table(metrics)
df

Loading val data: 100%|██████████| 543/543 [00:04<00:00, 128.59it/s]
Generating predictions: 100%|██████████| 462/462 [00:18<00:00, 24.78it/s]


Unnamed: 0,precision,recall,f1
exact_subtoken,0.308744,0.785037,0.443189
partial_subtoken,0.340972,0.80881,0.479711
exact_token,0.315114,0.788601,0.450296
partial_token,0.342346,0.808451,0.481006
exact_nonpositional,0.734114,0.814471,0.772208
partial_nonpositional,0.772857,0.843994,0.806861


In [24]:
checkpoint_dir = "./output/15May2020-14:47:51/checkpoint-288"
ner_system, trainer, vocab = load_model(checkpoint_dir, args_fp="./args.yaml")

batches = load_test_data(ner_system, "./data/dev-ner.json")

predictions = generate_predictions(ner_system, batches)

metrics = all_test_metrics(predictions, batches)

df = metrics_table(metrics)
df

Loading val data: 100%|██████████| 543/543 [00:04<00:00, 135.35it/s]
Generating predictions: 100%|██████████| 462/462 [00:18<00:00, 24.93it/s]


Unnamed: 0,precision,recall,f1
exact_subtoken,0.304903,0.785563,0.439299
partial_subtoken,0.33895,0.810684,0.478033
exact_token,0.312243,0.788382,0.447322
partial_token,0.339645,0.808451,0.478333
exact_nonpositional,0.726368,0.812616,0.767075
partial_nonpositional,0.765957,0.842434,0.802377


In [26]:
checkpoint_dir = "./output/15May2020-14:47:51/checkpoint-162"
ner_system, trainer, vocab = load_model(checkpoint_dir, args_fp="./args.yaml")

batches = load_test_data(ner_system, "./data/dev-ner.json")

predictions = generate_predictions(ner_system, batches)

metrics = all_test_metrics(predictions, batches)

df = metrics_table(metrics)
df

Loading val data: 100%|██████████| 543/543 [00:04<00:00, 135.11it/s]
Generating predictions: 100%|██████████| 462/462 [00:18<00:00, 24.90it/s]


Unnamed: 0,precision,recall,f1
exact_subtoken,0.292356,0.790202,0.426805
partial_subtoken,0.326332,0.81537,0.466113
exact_token,0.299646,0.792924,0.434932
partial_token,0.327163,0.813146,0.466595
exact_nonpositional,0.738333,0.82037,0.777193
partial_nonpositional,0.776034,0.848674,0.81073


In [27]:
checkpoint_dir = "./output/15May2020-14:47:51/checkpoint-180"
ner_system, trainer, vocab = load_model(checkpoint_dir, args_fp="./args.yaml")

batches = load_test_data(ner_system, "./data/dev-ner.json")

predictions = generate_predictions(ner_system, batches)

metrics = all_test_metrics(predictions, batches)

df = metrics_table(metrics)
df

Loading val data: 100%|██████████| 543/543 [00:04<00:00, 127.99it/s]
Generating predictions: 100%|██████████| 462/462 [00:18<00:00, 24.88it/s]


Unnamed: 0,precision,recall,f1
exact_subtoken,0.303412,0.791314,0.438638
partial_subtoken,0.336557,0.81537,0.476451
exact_token,0.310415,0.794792,0.44646
partial_token,0.338666,0.815023,0.478501
exact_nonpositional,0.737896,0.818519,0.776119
partial_nonpositional,0.775714,0.847114,0.809843


In [28]:
checkpoint_dir = "./output/15May2020-14:47:51/checkpoint-198"
ner_system, trainer, vocab = load_model(checkpoint_dir, args_fp="./args.yaml")

batches = load_test_data(ner_system, "./data/dev-ner.json")

predictions = generate_predictions(ner_system, batches)

metrics = all_test_metrics(predictions, batches)

df = metrics_table(metrics)
df

Loading val data: 100%|██████████| 543/543 [00:04<00:00, 134.68it/s]
Generating predictions: 100%|██████████| 462/462 [00:18<00:00, 24.88it/s]


Unnamed: 0,precision,recall,f1
exact_subtoken,0.311462,0.799574,0.448296
partial_subtoken,0.346472,0.823805,0.487791
exact_token,0.31937,0.803758,0.457109
partial_token,0.348292,0.823474,0.489534
exact_nonpositional,0.730579,0.818519,0.772052
partial_nonpositional,0.769122,0.847114,0.806236


In [29]:
checkpoint_dir = "./output/15May2020-14:47:51/checkpoint-216"
ner_system, trainer, vocab = load_model(checkpoint_dir, args_fp="./args.yaml")

batches = load_test_data(ner_system, "./data/dev-ner.json")

predictions = generate_predictions(ner_system, batches)

metrics = all_test_metrics(predictions, batches)

df = metrics_table(metrics)
df

Loading val data: 100%|██████████| 543/543 [00:04<00:00, 134.58it/s]
Generating predictions: 100%|██████████| 462/462 [00:18<00:00, 24.91it/s]


Unnamed: 0,precision,recall,f1
exact_subtoken,0.306164,0.802338,0.443205
partial_subtoken,0.339892,0.825679,0.481552
exact_token,0.313385,0.804777,0.451106
partial_token,0.340583,0.823474,0.481868
exact_nonpositional,0.735632,0.820513,0.775758
partial_nonpositional,0.771307,0.847114,0.807435
