Skip to content

Commit

Permalink
add analyze_bert_errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Dec 17, 2018
1 parent c3b794d commit e0e7363
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 1 deletion.
44 changes: 43 additions & 1 deletion examples/factrueval.ipynb
Expand Up @@ -75,7 +75,9 @@
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
Expand Down Expand Up @@ -512,6 +514,46 @@
"print(tokens_report)"
]
},
{
"cell_type": "code",
"execution_count": 133,
"metadata": {},
"outputs": [],
"source": [
"from modules.utils.plot_metrics import analyze_bert_errors"
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {},
"outputs": [],
"source": [
"res_tokens, res_labels, errors = analyze_bert_errors(dl, preds)"
]
},
{
"cell_type": "code",
"execution_count": 136,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"88"
]
},
"execution_count": 136,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len([error for error in errors if error])"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
28 changes: 28 additions & 0 deletions modules/utils/plot_metrics.py
Expand Up @@ -93,3 +93,31 @@ def get_elmo_span_report(dl, preds, ignore_labels=["O"]):
set_labels.update([y for x, y in spans_true[idx]])
set_labels -= set(ignore_labels)
return flat_classification_report([[y[1] for y in x] for x in spans_true], [[y[1] for y in x] for x in spans_pred], labels=list(set_labels), digits=3)


def analyze_bert_errors(dl, labels, fn=voting_choicer):
errors = []
res_tokens = []
res_labels = []
r_labels = [x.labels for x in dl.dataset]
for f, l_, rl in zip(dl.dataset, labels, r_labels):
label = fn(f.tok_map, l_)
label_r = fn(f.tok_map, rl)
prev_idx = 0
errors_ = []
assert len(label_r) == len(f.tokens) - 1
for idx, (l, rl, t) in enumerate(zip(label, label_r, f.tokens)):
if l != rl:
errors_.append({"token: ": t,
"real_label": rl,
"pred_label": l,
"bert_token": f.bert_tokens[prev_idx:f.tok_map[idx]],
"real_bert_label": f.labels[prev_idx:f.tok_map[idx]],
"pred_bert_label": l_[prev_idx:f.tok_map[idx]],
"text_example": " ".join(f.tokens[1:-1]),
"labels": " ".join(label_r[1:])})
prev_idx = f.tok_map[idx]
errors.append(errors_)
res_tokens.append(f.tokens[1:-1])
res_labels.append(label[1:])
return res_tokens, res_labels, errors

0 comments on commit e0e7363

Please sign in to comment.