Add multilabel support to Evaluation.test() (#5)
Evaluation.test() now supports multi-label classification as well. It
supports all the previously available standard metrics (precision, recall,
f1-score, accuracy) plus two new ones, 'hamming-loss' and 'exact-match'
(the latter equivalent to 'accuracy'). Once finished, the `test` function also
shows a binary confusion matrix for each possible label.
sergioburdisso committed May 14, 2020
1 parent 4571f99 commit 0a897dd
Showing 4 changed files with 149 additions and 59 deletions.
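
For context, a minimal usage sketch of the multi-label evaluation described in the commit message. The documents and label names below are made up for illustration, and it assumes the classifier is put into multi-label mode by training it with a list of labels per document:

from pyss3 import SS3
from pyss3.util import Evaluation

# Toy multi-label data (hypothetical): each document may belong to several categories
x_train = ["the team won the championship last night",
           "the new tax bill was approved by congress"]
y_train = [["sports"], ["politics", "economy"]]

x_test = ["the senator attended the championship final"]
y_test = [["politics", "sports"]]

clf = SS3()
clf.fit(x_train, y_train)  # lists of labels per document -> multi-label classifier

# For multi-label classifiers, `metric` now defaults to 'hamming-loss';
# 'exact-match' is also accepted and produces the same value as 'accuracy'.
hamming = Evaluation.test(clf, x_test, y_test, plot=False)
exact_match = Evaluation.test(clf, x_test, y_test, metric="exact-match", plot=False)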
6 changes: 6 additions & 0 deletions pyss3/__init__.py
@@ -148,6 +148,7 @@ def __init__(
:param sn_m: method used to compute the sanction (sn) function, options
are: "vanilla" and "xai" (default: "xai")
:type sn_m: str
:raises: ValueError
"""
self.__name__ = (name or self.__name__).lower()

@@ -156,6 +157,11 @@ def __init__(
self.__p__ = p or self.__p__
self.__a__ = a or self.__a__

try:
float(self.__s__ + self.__l__ + self.__p__ + self.__a__)
except BaseException:
raise ValueError("hyperparameter values must be numbers")

self.__categories_index__ = {}
self.__categories__ = []
self.__max_fr__ = []
187 changes: 128 additions & 59 deletions pyss3/util.py
@@ -9,7 +9,7 @@
from collections import defaultdict

from numpy import mean, linspace, arange
from sklearn.metrics import classification_report, accuracy_score
from sklearn.metrics import classification_report, accuracy_score, hamming_loss
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold

@@ -20,6 +20,16 @@
import json
import re

try:
from sklearn.metrics import multilabel_confusion_matrix
except ImportError:
print("\033[93m* Your Scikit-learn version does not include `multilabel_confusion_matrix()`.\n"
"* Update Scikit-learn in case you want to work with multi-label classification.\033[0m")

def multilabel_confusion_matrix(*args):
"""Dummy version of multilabel_confusion_matrix."""
return []

ENCODING = "utf-8"

REGEX_DATE = re.compile(
@@ -40,9 +50,10 @@

STR_ACCURACY, STR_PRECISION = "accuracy", "precision"
STR_RECALL, STR_F1 = "recall", "f1-score"
STR_HAMMING_LOSS, STR_EXACT_MATCH = "hamming-loss", "exact-match"
METRICS = [STR_PRECISION, STR_RECALL, STR_F1]
EXCP_METRICS = [STR_ACCURACY, "confusion_matrix", "categories"]
AVGS = ["micro avg", "macro avg", "weighted avg"]
EXCP_METRICS = [STR_ACCURACY, STR_HAMMING_LOSS, "confusion_matrix", "categories"]
AVGS = ["micro avg", "macro avg", "weighted avg", "samples avg"]

STR_TEST, STR_FOLD = 'test', 'fold'
STR_MOST_PROBABLE = "most-probable"
@@ -186,7 +197,8 @@ def __cache_update__():

@staticmethod
def __cache_save_result__(
cache, categories, accuracy, report, conf_matrix, k_fold, i_fold, s, l, p, a
cache, categories, accuracy, report, conf_matrix, k_fold, i_fold,
s, l, p, a, hamming_loss=None
):
"""Compute, update and finally save evaluation results to disk."""
rf = round_fix
@@ -198,6 +210,8 @@ def __cache_save_result__(
# if there aren't previous best results, initialize them to -1
if cache["accuracy"]["best"]["value"] == {}:
cache["accuracy"]["best"]["value"] = -1
if hamming_loss is not None:
cache["hamming_loss"]["best"]["value"] = -1
for metric, avg in product(METRICS, AVGS):
if avg in report: # scikit-learn > 0.20 does not include 'micro avg' in report
cache[metric][avg]["best"]["value"] = -1
@@ -208,6 +222,8 @@ def __cache_save_result__(
# if fold results array is empty, create new ones
if cache["accuracy"]["fold_values"][s][l][p][a] == {}:
cache["accuracy"]["fold_values"][s][l][p][a] = [0] * k_fold
if hamming_loss is not None:
cache["hamming_loss"]["fold_values"][s][l][p][a] = [0] * k_fold
cache["confusion_matrix"][s][l][p][a] = [None] * k_fold
for metric, avg in product(METRICS, AVGS):
if avg in report:
@@ -218,6 +234,8 @@ def __cache_save_result__(

# saving fold results
cache["accuracy"]["fold_values"][s][l][p][a][i_fold] = rf(accuracy)
if hamming_loss is not None:
cache["hamming_loss"]["fold_values"][s][l][p][a][i_fold] = rf(1 - hamming_loss)
for metric, avg in product(METRICS, AVGS):
if avg in report:
cache[metric][avg]["fold_values"][s][l][p][a][i_fold] = rf(report[avg][metric])
@@ -239,6 +257,16 @@ def __cache_save_result__(
best_acc["s"], best_acc["l"] = s, l
best_acc["p"], best_acc["a"] = p, a

if hamming_loss is not None:
hamloss_avg = rf(mean(cache["hamming_loss"]["fold_values"][s][l][p][a]))
cache["hamming_loss"]["value"][s][l][p][a] = hamloss_avg

best_haml = cache["hamming_loss"]["best"]
if hamloss_avg > best_haml["value"]:
best_haml["value"] = hamloss_avg
best_haml["s"], best_haml["l"] = s, l
best_haml["p"], best_haml["a"] = p, a

for metric, avg in product(METRICS, AVGS):
if avg in report:
metric_avg = rf(mean(cache[metric][avg]["fold_values"][s][l][p][a]))
@@ -493,10 +521,12 @@ def __classification_report_k_fold__(
raise KeyError(ERROR_NAT % str(metric_target))

@staticmethod
def __plot_confusion_matrices__(cms, classes, info='', max_colums=3):
def __plot_confusion_matrices__(cms, classes, info='', max_colums=3, multilabel=False):
"""Show and plot the confusion matrices."""
import matplotlib.pyplot as plt

categories = classes
classes = ['no', 'yes'] if multilabel else classes
n_cms = len(cms)

rows = int(ceil(n_cms / (max_colums + .0)))
@@ -528,14 +558,16 @@ def __plot_confusion_matrices__(cms, classes, info='', max_colums=3):

if n_cms == 1:
ax.set_title(title + '\n', fontweight="bold")
elif multilabel:
ax.set_title(categories[axi], fontweight="bold")

if (axi % max_colums) == 0:
ax.set_ylabel('True', fontweight="bold")
ax.set_yticklabels(classes)
else:
ax.tick_params(labelleft=False)

if axi + 1 > n_cms - max_colums:
if axi + 1 > n_cms - max_colums or multilabel:
ax.set_xlabel('Predicted', fontweight="bold")
ax.set_xticklabels(classes)
else:
@@ -585,11 +617,23 @@ def __evaluation_result__(
if force_show:
Print.verbosity_region_begin(VERBOSITY.VERBOSE, force=True)

n_cats = len(categories)
if def_cat == STR_UNKNOWN:
if categories[-1] != STR_UNKNOWN_CATEGORY:
categories += [STR_UNKNOWN_CATEGORY]
y_pred = [y if y != IDX_UNKNOWN_CATEGORY else n_cats for y in y_pred]
multilabel = clf.__multilabel__
hammingloss = None

if metric == STR_HAMMING_LOSS and not multilabel:
raise ValueError("the '%s' metric is only allowed in multi-label classification."
% STR_HAMMING_LOSS)

if not multilabel:
n_cats = len(categories)
if def_cat == STR_UNKNOWN:
if categories[-1] != STR_UNKNOWN_CATEGORY:
categories += [STR_UNKNOWN_CATEGORY]
y_pred = [y if y != IDX_UNKNOWN_CATEGORY else n_cats for y in y_pred]
else:
y_pred = membership_matrix(clf, y_pred, labels=False)
y_true = membership_matrix(clf, y_true, labels=False)
hammingloss = hamming_loss(y_pred, y_true)

accuracy = accuracy_score(y_pred, y_true)
Print.show()
@@ -599,46 +643,51 @@ def __evaluation_result__(
labels=range(len(categories)), target_names=categories
)
)
Print.show(
"\n %s: %.3f"
%
(Print.style.bold("accuracy"), accuracy)
)

unclassified = None
if tag and def_cat == STR_UNKNOWN:
unclassified = sum(map(lambda v: v == n_cats, y_pred))
if not multilabel:
Print.show("\n %s: %.3f" % (Print.style.bold("Accuracy"), accuracy))
else:
Print.show("\n %s: %.3f" % (Print.style.bold("Exact Match Ratio"), accuracy))
Print.show("\n %s: %.3f" % (Print.style.bold("Hamming Loss"), hammingloss))

if not multilabel:
unclassified = None
if tag and def_cat == STR_UNKNOWN:
unclassified = sum(map(lambda v: v == n_cats, y_pred))

if tag and unclassified:
cat_acc = []
for cat in clf.get_categories():
cat_acc.append((
cat,
accuracy_score(
[
clf.get_category_index(cat) if y == n_cats else y
for y in y_pred
],
y_true
)
))

if tag and unclassified:
cat_acc = []
for cat in clf.get_categories():
cat_acc.append((
cat,
accuracy_score(
[
clf.get_category_index(cat) if y == n_cats else y
for y in y_pred
],
y_true
)
))
best_acc = sorted(cat_acc, key=lambda e: -e[1])[0]
Print.warn(
"A better accuracy (%.3f) would be obtained "
"with '%s' as the default category"
%
(best_acc[1], best_acc[0])
)
Print.warn(
"(Since %d%% of the documents were classified as 'unknown')"
%
(unclassified * 100.0 / len(y_true))
)

best_acc = sorted(cat_acc, key=lambda e: -e[1])[0]
Print.warn(
"A better accuracy (%.3f) would be obtained "
"with '%s' as the default category"
%
(best_acc[1], best_acc[0])
)
Print.warn(
"(Since %d%% of the documents were classified as 'unknown')"
%
(unclassified * 100.0 / len(y_true))
)
Print.show()

Print.show()
if not multilabel:
conf_matrix = confusion_matrix(y_true, y_pred)
else:
conf_matrix = multilabel_confusion_matrix(y_true, y_pred)

conf_matrix = confusion_matrix(y_true, y_pred)
report = classification_report(
y_true, y_pred,
labels=range(len(categories)), target_names=categories,
@@ -652,15 +701,15 @@ def __evaluation_result__(
Evaluation.__cache_get_evaluations__(tag, method, def_cat),
categories, accuracy, report,
conf_matrix, k_fold, i_fold,
s, l, p, a
s, l, p, a, hamming_loss=hammingloss
)

if plot:
Evaluation.__plot_confusion_matrices__(
[conf_matrix], categories,
[conf_matrix] if not multilabel else conf_matrix, categories,
r"$\sigma=%.3f; \lambda=%.3f; \rho=%.3f; \alpha=%.3f$"
%
(s, l, p, a)
(s, l, p, a), multilabel=multilabel
)

warnings.filterwarnings('default')
Expand All @@ -669,6 +718,8 @@ def __evaluation_result__(

if metric == STR_ACCURACY:
return accuracy
elif metric == STR_HAMMING_LOSS:
return hammingloss
else:
if metric_target not in report:
raise KeyError(ERROR_NAT % str(metric_target))
Expand Down Expand Up @@ -783,8 +834,16 @@ def set_classifier(clf):
Evaluation.__cache_load__()

@staticmethod
def clear_cache():
"""Wipe out the evaluation cache."""
def clear_cache(clf=None):
"""
Wipe out the evaluation cache (for the given classifier).
:param clf: the classifier (optional)
:type clf: SS3
"""
if clf is not None:
Evaluation.set_classifier(clf)

Evaluation.__cache__ = None
clf = Evaluation.__clf__
if clf:
@@ -1123,7 +1182,7 @@ def remove(s=None, l=None, p=None, a=None, method=None, def_cat=None, tag=None,
@staticmethod
def test(
clf, x_test, y_test, def_cat=STR_MOST_PROBABLE, prep=True,
tag=None, plot=True, metric='accuracy', metric_target='macro avg', cache=True
tag=None, plot=True, metric=None, metric_target='macro avg', cache=True
):
"""
Test the model using the given test set.
@@ -1144,8 +1203,8 @@ def test(
:type clf: SS3
:param x_test: the test set documents, i.e, the list of documents to be classified
:type x_test: list (of str)
:param y_test: the test set category labels, i.e, the list of document labels
:type y_test: list (of str)
:param y_test: the test set with category labels, i.e, the list of document labels
:type y_test: list (of str) or list (of list of str)
:param def_cat: default category to be assigned when SS3 is not
able to classify a document. Options are
'most-probable', 'unknown' or a given category name.
@@ -1162,7 +1221,10 @@ def test(
:type plot: bool
:param metric: the evaluation metric to return, options are:
'accuracy', 'f1-score', 'precision', or 'recall'
(default: 'accuracy').
When working with multi-label classification problems,
two more options are allowed: 'hamming-loss' and 'exact-match'.
Note: 'exact-match' will produce the same result as 'accuracy'.
(default: 'accuracy', or 'hamming-loss' in the multi-label case).
:type metric: str
:param metric_target: the target we aim at measuring with the given
metric. Options are: 'macro avg', 'micro avg',
@@ -1175,7 +1237,7 @@ def test(
:type cache: bool
:returns: the given metric value, by default, the obtained accuracy.
:rtype: float
:raises: EmptyModelError, KeyError
:raises: EmptyModelError, KeyError, ValueError
"""
Evaluation.set_classifier(clf)
tag = tag or Evaluation.__cache_get_default_tag__(clf, x_test)
@@ -1189,12 +1251,19 @@ def test(
tag, def_cat, s, l, p, a
)

multilabel = clf.__multilabel__
metric = metric or (STR_ACCURACY if not multilabel else STR_HAMMING_LOSS)
metric = metric if metric != STR_EXACT_MATCH else STR_ACCURACY

# if not cached
if not y_pred:
if not y_pred or multilabel:
clf.set_hyperparameters(s, l, p, a)
y_pred = clf.predict(x_test, def_cat, prep=prep, labels=False)
categories = clf.get_categories()
y_test = [clf.get_category_index(y) for y in y_test]
if not multilabel:
y_test = [clf.get_category_index(y) for y in y_test]
else:
y_test = [[clf.get_category_index(y) for y in yy] for yy in y_test]
else:
y_test = _y_test

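
For reference, a minimal standalone sketch (with made-up data) of the scikit-learn calls the multi-label branch of `__evaluation_result__` relies on: predictions and ground truth are first turned into binary membership matrices (rows = documents, columns = labels), 'hamming-loss' comes from `hamming_loss()`, 'exact-match' is plain `accuracy_score()` on those matrices (i.e., subset accuracy), and `multilabel_confusion_matrix()` yields one binary ('no'/'yes') confusion matrix per label:

import numpy as np
from sklearn.metrics import accuracy_score, hamming_loss, multilabel_confusion_matrix

# Hypothetical binary membership matrices, analogous to what membership_matrix() produces
y_true = np.array([[1, 0, 1],
                   [0, 1, 0],
                   [1, 1, 0]])
y_pred = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [1, 0, 0]])

print(hamming_loss(y_true, y_pred))                 # 2 mismatched cells out of 9 -> ~0.222
print(accuracy_score(y_true, y_pred))               # exact-match ratio: 1 of 3 rows -> ~0.333
print(multilabel_confusion_matrix(y_true, y_pred))  # one 2x2 matrix per label (column)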
3 changes: 3 additions & 0 deletions tests/test_pyss3.py
@@ -290,6 +290,9 @@ def test_multilabel():

def test_pyss3_ss3(mockers):
"""Test SS3."""
with pytest.raises(ValueError):
clf = SS3("hyperparameter")

clf = SS3(
s=.45, l=.5, p=1, a=0,
cv_m=STR_NORM_GV_XAI, sn_m=STR_XAI
