In [14]:
import numpy as np
from functools import reduce

import os
import sys
# Put the parent of the notebook's working directory on the import path
# (presumably the repository root that holds the `src` package imported below).
# The membership guard keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if root not in sys.path:
    sys.path.append(root)

from validation import get_error_propagation_prob, get_accuracy, get_precision, get_recall, get_f1
from src.scrapper import parse_conllu_file
from src.tagger import HiddenMarkovModel, HiddenMarkovModelTagger
from src.visualization import plot_viterbi_path_binary, plot_viterbi_matrix

In [15]:
# Load the train/test CoNLL-U splits, then fit the HMM on the training split
train_path = "../datasets/en_gum-ud-train.conllu"
test_path = "../datasets/en_gum-ud-test.conllu"

train = parse_conllu_file(filepath=train_path)
test = parse_conllu_file(filepath=test_path)

# `train()` returns the object used for prediction in the cells below
# (presumably a HiddenMarkovModelTagger -- confirm in src.tagger)
hmm = HiddenMarkovModel(corpus=train)
tagger = hmm.train()

In [16]:
# Tag the held-out test corpus with the trained tagger
test_predictions = tagger.predict(corpus=test)

# Report the error-propagation probability of the predictions against the gold corpus
error_propagation_prob = get_error_propagation_prob(test, test_predictions)
print(error_propagation_prob)

0.24769617425300194


In [17]:
# Build the confusion matrix of gold test tags versus predicted tags
tagset = tagger.tagset  # tag inventory exposed by the trained tagger
cm = tagger.get_confusion_matrix(test, test_predictions)

# Render the matrix as plain text: no decimals, space-separated columns,
# lines wrapped at 100 characters
cm_formatted = np.array2string(
    cm,
    max_line_width=100,
    precision=0,
    suppress_small=True,
    separator=' ',
)
print(cm_formatted)

[[ 179    0    0    0    0    0    0    0   69    0    0    0    2    0    0    0   10    0]
 [   0  881    2   24    0    0    0    0  346    0    0    0    7    0    0    0   67    0]
 [   0    0 2029    3    0    2    0    0    7    0    0    3    0    0    2    0    1    0]
 [   0   52   85  652    0    0   19    0   67    0    0   27    8    0    0    0   15    0]
 [   0    0    0    0  812    0    0    0   10    0    0    0    0    0    0    0  104    0]
 [   0    0    0    2    0  680   18    0    0    0    0    0    0    0    0    0    0    0]
 [   0    1    0    0    0    3 1660    0    2    0    0   55    0    0    0    0    0    0]
 [   0    2   12   30    0    0    5   75   14    0    0    0    1    0    0    0    9    0]
 [   0   13    0    5    1    0    1    0 3290    0    0    0   48    0    0    0  150    0]
 [   0    0    0    0    0    0    0    0  144  237    0    1   17    0    0    0    2    0]
 [   0    0  222    0   71    0    1    0    2    0  122    0    0    

In [18]:
# Evaluation report derived from the confusion matrix `cm`:
# overall accuracy, per-tag precision and recall, and aggregate F1 scores.
# NOTE(review): the per-tag loops pair `tagset` with the metric arrays by
# position, so they assume `tagset` is ordered like the rows/columns of `cm`.

# Accuracy calculations
print("MODEL ACCURACY")
print("--------------")

acc = get_accuracy(cm)
print(f"Model accuracy: {acc:.4f}")
print()


# Precision calculations
print("MODEL PRECISION")
print("---------------")

precs, p_preds, p_micro, p_macro = get_precision(cm)
prec_data = zip(tagset, precs, p_preds)
# One line per tag: precision and the number of predictions it was measured on
for tag_name, tag_precision, n_predicted in prec_data:
    print(f"{tag_name}: {tag_precision:.4f} over {int(n_predicted)} predictions.")

print(f"\nWeighted average model micro precision: {p_micro:.4f}")
print(f"Average model macro precision: {p_macro:.4f}")
print()


# Recall calculations
print("MODEL RECALL")
print("------------")

recalls, r_preds, r_micro, r_macro = get_recall(cm)
recall_data = zip(tagset, recalls, r_preds)
# Recall is measured against the gold occurrences of each tag (the counts in
# `r_preds` match the row sums of `cm`), so the count is labelled "instances"
# rather than "predictions".
for tag_name, tag_recall, n_actual in recall_data:
    print(f"{tag_name}: {tag_recall:.4f} over {int(n_actual)} instances.")

print(f"\nWeighted average model micro recall: {r_micro:.4f}")
print(f"Average model macro recall: {r_macro:.4f}")
print()

# F1 calculations
print("MODEL F1")
print("------------")
# Only the aggregate scores are printed; the per-tag values (`f1s`) and counts
# (`predictions`) stay bound in the notebook namespace but are not shown here.
f1s, predictions, micro_f1, macro_f1 = get_f1(cm)
print(f"Model micro F1: {micro_f1:.4f}")
print(f"Model macro F1: {macro_f1:.4f}")

MODEL ACCURACY
--------------
Model accuracy: 0.8213

MODEL PRECISION
---------------
_: 1.0000 over 179 predictions.
adj: 0.8731 over 1009 predictions.
adp: 0.8261 over 2456 predictions.
adv: 0.8968 over 727 predictions.
aux: 0.8354 over 972 predictions.
cconj: 0.9913 over 686 predictions.
det: 0.9640 over 1722 predictions.
intj: 1.0000 over 75 predictions.
noun: 0.6120 over 5376 predictions.
num: 1.0000 over 237 predictions.
part: 0.9919 over 123 predictions.
pron: 0.9104 over 1551 predictions.
propn: 0.7920 over 500 predictions.
punct: 0.9914 over 2562 predictions.
sconj: 0.9762 over 84 predictions.
sym: 1.0000 over 8 predictions.
verb: 0.7936 over 1899 predictions.
x: 1.0000 over 5 predictions.

Weighted average model micro precision: 0.8459
Average model macro precision: 0.9141

MODEL RECALL
------------
_: 0.6885 over 260 predictions.
adj: 0.6639 over 1327 predictions.
adp: 0.9912 over 2047 predictions.
adv: 0.7049 over 925 predictions.
aux: 0.8769 over 926 predictions.
cconj: 0.