In [1]:
import numpy as np
from functools import reduce

import os
import sys
# Make the project root (the parent of the notebook's working directory)
# importable so that the `src.*` imports below resolve. The membership guard
# keeps re-runs of this cell from piling duplicate entries onto sys.path.
root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if root not in sys.path:
    sys.path.append(root)

from validation import get_error_propagation_prob, get_accuracy, get_precision, get_recall, get_f1
from src.scrapper import parse_conllu_file
from src.tagger import HiddenMarkovModel, HiddenMarkovModelTagger
from src.visualization import plot_viterbi_path_binary, plot_viterbi_matrix

In [2]:
# Load the en_partut UD train/test splits (paths are relative to the
# notebook's working directory) and fit the HMM on the training split.
train_file = "../datasets/en_partut-ud-train.conllu"
test_file = "../datasets/en_partut-ud-test.conllu"

train = parse_conllu_file(filepath=train_file)
test = parse_conllu_file(filepath=test_file)

# Later cells call .predict / .tagset / .get_confusion_matrix on whatever
# train() returns (presumably the fitted model itself — confirm in src.tagger).
hmm_model = HiddenMarkovModel(corpus=train)
tagger = hmm_model.train()

In [3]:
# Run the trained tagger over the test corpus, then report the
# error-propagation probability as defined by
# validation.get_error_propagation_prob.
test_predictions = tagger.predict(corpus=test)
propagation_prob = get_error_propagation_prob(test, test_predictions)
print(propagation_prob)

0.2673684210526316


In [4]:
# Confusion matrix over the tagger's tagset. Judging by the recorded metrics,
# rows are gold tags and columns are predicted tags, both in `tagset` order
# (e.g. row 0 sums to 234, which is the "aux" recall count).
tagset = tagger.tagset  # read corpus tagset; reused to label per-tag metrics
cm = tagger.get_confusion_matrix(test, test_predictions)

# Print the tag order so rows/columns can be mapped back to tags, then the
# full matrix. threshold=cm.size disables numpy's "..." summarization, which
# would otherwise elide cells once the matrix exceeds 1000 entries (~32 tags).
print("Tag order:", ", ".join(map(str, tagset)))
cm_formatted = np.array2string(
    cm,
    precision=0,
    separator=' ',
    suppress_small=True,
    max_line_width=100,
    threshold=cm.size,
)
print(cm_formatted)

[[216   0   0   0   0   0   0   0   1   0   0   0   0   0   0   0  17   0]
 [  0   1   0   0   0   0   0   0   2   0   0   0   0   0   0   0   1   0]
 [  1   0  31  33   1   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  1   0   0 485   0   0   0   0   2   0   0   0   0   0   0   0   0   0]
 [  6   0   0   0 333   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  1   0   0   0   0  86   0   0   1  16   0   0   0   1   0   2   0   2]
 [  1   0   0  11   0   0  70   0  27   2   0   0   0   1   0   0   8   8]
 [  1   0   0   0   0   0   0   0   0   0   1   0   0   0   0   0   0   0]
 [  2   0   0   0   0   0   0   0 732   0   2   0   0   0   0   0   9   9]
 [  1   0   0   0   0   4   0   0   0 434   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0  45   0  39   0   0   0   0   0   0   6]
 [  0   0   0   1   0   0   0   1   0   0   0  94   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0  13   0   7

In [5]:
# Per-tag and aggregate metrics derived from the confusion matrix `cm`
# (rows = gold tag, columns = predicted tag, both in `tagset` order).


def _print_header(title: str) -> None:
    """Print a section title underlined with a dash rule of the same width."""
    print(title)
    print("-" * len(title))


def _print_per_tag(tags, scores, counts, unit: str) -> None:
    """Print one ``"<tag>: <score> over <count> <unit>."`` line per tag.

    Args:
        tags: tag labels, in the same order as ``scores`` and ``counts``.
        scores: per-tag metric values (formatted to 4 decimals).
        counts: per-tag denominators behind each score.
        unit: noun describing what ``counts`` counts.

    Raises:
        ValueError: if the three sequences differ in length; ``zip`` would
            otherwise silently drop tags or attach scores to the wrong label.
    """
    tags, scores, counts = list(tags), list(scores), list(counts)
    if not len(tags) == len(scores) == len(counts):
        raise ValueError(
            f"length mismatch: {len(tags)} tags, {len(scores)} scores, "
            f"{len(counts)} counts"
        )
    for tag, score, count in zip(tags, scores, counts):
        print(f"{tag}: {score:.4f} over {int(count)} {unit}.")


# Accuracy calculations
_print_header("MODEL ACCURACY")
acc = get_accuracy(cm)
print(f"Model accuracy: {acc:.4f}")
print()


# Precision calculations
_print_header("MODEL PRECISION")
# p_preds: per-tag number of predictions (presumably the column sums of `cm`).
precs, p_preds, p_micro, p_macro = get_precision(cm)
_print_per_tag(tagset, precs, p_preds, "predictions")
# NOTE(review): for single-label tagging, true micro-averaged precision equals
# accuracy, yet the recorded "micro" value (0.8713) differs from accuracy
# (0.8602); get_precision's "micro" is presumably a support-weighted average —
# confirm in validation.py before relying on the label.
print(f"\nWeighted average model micro precision: {p_micro:.4f}")
print(f"Average model macro precision: {p_macro:.4f}")
print()


# Recall calculations
_print_header("MODEL RECALL")
# Despite its name, r_preds holds per-tag *gold* occurrence counts (the row
# sums of `cm`), i.e. the recall denominators, not prediction counts.
recalls, r_preds, r_micro, r_macro = get_recall(cm)
_print_per_tag(tagset, recalls, r_preds, "gold tokens")
print(f"\nWeighted average model micro recall: {r_micro:.4f}")
print(f"Average model macro recall: {r_macro:.4f}")
print()


# F1 calculations
_print_header("MODEL F1")
micro_f1 = get_f1(p_micro, r_micro)
# Presumably the harmonic mean of macro precision and macro recall.
macro_f1 = get_f1(p_macro, r_macro)

# The conventional macro F1 (e.g. scikit-learn's average="macro") is the
# unweighted mean of per-tag F1 scores, which generally differs from the value
# above. Tags with P + R == 0 get F1 = 0 instead of a division by zero. All
# tags are averaged, matching how p_macro / r_macro average over every tag.
prec_arr = np.asarray(precs, dtype=float)
rec_arr = np.asarray(recalls, dtype=float)
pr_sum = prec_arr + rec_arr
per_tag_f1 = np.divide(
    2 * prec_arr * rec_arr, pr_sum, out=np.zeros_like(pr_sum), where=pr_sum > 0
)
macro_f1_per_tag = float(per_tag_f1.mean()) if per_tag_f1.size else 0.0

print(f"Model micro F1: {micro_f1:.4f}")
print(f"Model macro F1 (F1 of macro P/R): {macro_f1:.4f}")
print(f"Model macro F1 (mean of per-tag F1): {macro_f1_per_tag:.4f}")

MODEL ACCURACY
--------------
Model accuracy: 0.8602

MODEL PRECISION
---------------
aux: 0.8780 over 246 predictions.
_: 1.0000 over 1 predictions.
part: 1.0000 over 31 predictions.
adp: 0.8899 over 545 predictions.
punct: 0.9970 over 334 predictions.
pron: 0.8687 over 99 predictions.
adv: 0.9589 over 73 predictions.
x: 0.0000 over 1 predictions.
noun: 0.7342 over 997 predictions.
det: 0.9602 over 452 predictions.
propn: 0.9070 over 43 predictions.
cconj: 1.0000 over 94 predictions.
sym: 0.0000 over 0 predictions.
sconj: 0.9355 over 31 predictions.
intj: 0.0000 over 0 predictions.
num: 0.9545 over 44 predictions.
verb: 0.8092 over 262 predictions.
adj: 0.8239 over 159 predictions.

Weighted average model micro precision: 0.8713
Average model macro precision: 0.7621

MODEL RECALL
------------
aux: 0.9231 over 234 predictions.
_: 0.2500 over 4 predictions.
part: 0.4697 over 66 predictions.
adp: 0.9939 over 488 predictions.
punct: 0.9823 over 339 predictions.
pron: 0.7890 over 109 predi