Skip to content

Commit

Permalink
Change multilabel classification default category
Browse files Browse the repository at this point in the history
``predict`` and ``classify_multilabel`` now have "unknown" as the
default category when working with multi-label classification. When this
category is used, they return an empty list [] to indicate that the
predicted category is unknown.

In addition, ``kmean_multilabel_size`` returns 0 when the centroids of
the two clusters have the same value, which means that ``predict`` or
``classify_multilabel`` will return an empty list [] when the predicted
confidence values are all equal.
  • Loading branch information
sergioburdisso committed May 19, 2020
1 parent 15657ee commit 8b2ea60
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 38 deletions.
57 changes: 31 additions & 26 deletions pyss3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,22 +857,16 @@ def __predict_fast__(
else:
if pred_cv.sum() == 0:
if def_cat == IDX_UNKNOWN_CATEGORY:
y_pred[doc_idx] = [STR_UNKNOWN_CATEGORY if labels else def_cat]
y_pred[doc_idx] = []
else:
y_pred[doc_idx] = [self.get_category_name(def_cat) if labels else def_cat]
else:
r = sorted(
[
(i, pred_cv[i])
for i in range(pred_cv.size)
],
key=lambda e: -e[1]
)
r = sorted([(i, pred_cv[i])
for i in range(pred_cv.size)],
key=lambda e: -e[1])
if labels:
y_pred[doc_idx] = [
self.get_category_name(cat_i)
for cat_i, _ in r[:kmean_multilabel_size(r)]
]
y_pred[doc_idx] = [self.get_category_name(cat_i)
for cat_i, _ in r[:kmean_multilabel_size(r)]]
else:
y_pred[doc_idx] = [cat_i for cat_i, _ in r[:kmean_multilabel_size(r)]]

Expand Down Expand Up @@ -1931,6 +1925,9 @@ def learn(self, doc, cat, n_grams=1, prep=True, update=True):
"""
self.__cv_cache__ = None

if not doc or not cat:
return

try:
doc = doc.decode(ENCODING)
except UnicodeEncodeError: # for python 2 compatibility
Expand Down Expand Up @@ -2143,7 +2140,7 @@ def classify_label(self, doc, def_cat=STR_MOST_PROBABLE, labels=True, prep=True)

return cat if labels else self.get_category_index(cat)

def classify_multilabel(self, doc, def_cat=STR_MOST_PROBABLE, labels=True, prep=True):
def classify_multilabel(self, doc, def_cat=STR_UNKNOWN, labels=True, prep=True):
"""
Classify a given document returning multiple category labels.
Expand All @@ -2156,7 +2153,7 @@ def classify_multilabel(self, doc, def_cat=STR_MOST_PROBABLE, labels=True, prep=
:param def_cat: default category to be assigned when SS3 is not
able to classify a document. Options are
"most-probable", "unknown" or a given category name.
(default: "most-probable")
(default: "unknown")
:type def_cat: str
:param labels: whether to return the category labels or just the
category indexes (default: True)
Expand All @@ -2171,7 +2168,7 @@ def classify_multilabel(self, doc, def_cat=STR_MOST_PROBABLE, labels=True, prep=

if not r or not r[0][1]:
if not def_cat or def_cat == STR_UNKNOWN:
cat = STR_UNKNOWN_CATEGORY
return []
elif def_cat == STR_MOST_PROBABLE:
cat = self.get_most_probable_category()
else:
Expand All @@ -2185,10 +2182,8 @@ def classify_multilabel(self, doc, def_cat=STR_MOST_PROBABLE, labels=True, prep=
else:
__other_idx__ = self.get_category_index(STR_OTHERS_CATEGORY)
if labels:
result = [
self.get_category_name(cat_i)
for cat_i, _ in r[:kmean_multilabel_size(r)]
]
result = [self.get_category_name(cat_i)
for cat_i, _ in r[:kmean_multilabel_size(r)]]
# removing "hidden" special category ("[other]")
if __other_idx__ != IDX_UNKNOWN_CATEGORY and STR_OTHERS_CATEGORY in result:
result.remove(STR_OTHERS_CATEGORY)
Expand Down Expand Up @@ -2313,7 +2308,7 @@ def predict_proba(self, x_test, prep=True, leave_pbar=True):
]

def predict(
self, x_test, def_cat=STR_MOST_PROBABLE,
self, x_test, def_cat=None,
labels=True, multilabel=False, prep=True, leave_pbar=True
):
"""
Expand All @@ -2324,6 +2319,8 @@ def predict(
:param def_cat: default category to be assigned when SS3 is not
able to classify a document. Options are
"most-probable", "unknown" or a given category name.
(default: "most-probable", or "unknown" for
multi-label classification)
:type def_cat: str
:param labels: whether to return the list of category names or just
category indexes
Expand All @@ -2347,11 +2344,17 @@ def predict(
if not self.__categories__:
raise EmptyModelError

multilabel = multilabel or self.__multilabel__

if def_cat is None:
def_cat = STR_UNKNOWN if multilabel else STR_MOST_PROBABLE

if not def_cat or def_cat == STR_UNKNOWN:
Print.info(
"default category was set to 'unknown' (its index will be -1)",
offset=1
)
if not multilabel:
Print.info(
"default category was set to 'unknown' (its index will be -1)",
offset=1
)
else:
if def_cat == STR_MOST_PROBABLE:
Print.info(
Expand All @@ -2364,8 +2367,6 @@ def predict(
if self.get_category_index(def_cat) == IDX_UNKNOWN_CATEGORY:
raise InvalidCategoryError

multilabel = multilabel or self.__multilabel__

if self.get_ngrams_length() == 1 and self.__summary_ops_are_pristine__():
return self.__predict_fast__(x_test, def_cat=def_cat, labels=labels,
multilabel=multilabel, prep=prep,
Expand Down Expand Up @@ -2547,6 +2548,10 @@ def kmean_multilabel_size(res):
clust = {"neg": [], "pos": []} # clusters (2 clusters: "pos" and "neg")
new_cent_neg = res[-1][1]
new_cent_pos = res[0][1]

if new_cent_neg == new_cent_pos:
return 0

while (cent["pos"] != new_cent_pos) or (cent["neg"] != new_cent_neg):
cent["neg"], cent["pos"] = new_cent_neg, new_cent_pos
clust["neg"], clust["pos"] = [], []
Expand Down
1 change: 1 addition & 0 deletions pyss3/resources/live_test/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ <h5 style="margin-top: 0;" ng-if="multilabel">
<span ng-if="multilabel" ng-repeat="cv_i in ss3.cvns" ng-show="is_cat_active(cv_i)">
<div class="chip pointer" ng-class="{'label-ok': is_in_golden_true(cv_i), 'label-nok': !is_in_golden_true(cv_i)}" ng-click="select_cat(cv_i[0])">{{ss3.ci[cv_i[0]]}}</div>
</span>
<span ng-if="get_n_active_cats() == 0">N/A</span>
</h5>
<h5 style="margin-top: 0;" ng-if="!multilabel">
<!-- <span style="color: black" id="hashtag">Main Category: </span> -->
Expand Down
15 changes: 15 additions & 0 deletions pyss3/resources/live_test/js/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,17 @@ app.controller("mainCtrl", function($scope) {
$scope.chart.obj.setSelection({row:null,colum:null});
}

$scope.get_n_active_cats = function(){
    // Returns how many consecutive entries, starting from the front of
    // $scope.ss3.cvns, are reported active by $scope.is_cat_active.
    // Counting stops at the first inactive entry, so any active entry
    // located after an inactive one is not counted.
    // NOTE(review): presumably cvns is ordered so that active categories
    // always come first -- confirm against the code that builds it.
    var n_active = 0;
    while (n_active < $scope.ss3.cvns.length &&
           $scope.is_cat_active($scope.ss3.cvns[n_active])){
        n_active++;
    }
    return n_active;
}

$scope.is_cat_active = function (cat_info) {
var icat = Number.isInteger(cat_info)? cat_info : cat_info[0];
return active_cats.indexOf(icat) != -1;
Expand Down Expand Up @@ -281,6 +292,10 @@ app.controller("mainCtrl", function($scope) {
var new_cent_neg = cats[cats.length - 1][2];
var new_cent_pos = cats[0][2];
var active_cats = null;

if (new_cent_neg == new_cent_pos)
return [];

while (cent.pos != new_cent_pos || cent.neg != new_cent_neg){
cent.neg = new_cent_neg;
cent.pos = new_cent_pos;
Expand Down
2 changes: 1 addition & 1 deletion pyss3/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def multilabel_confusion_matrix(*args):
ERROR_NAT += "(excepted: %s or a category label)" % (", ".join(["'%s'" % a for a in AVGS]))
ERROR_CNE = "the classifier has not been evaluated yet"
ERROR_CNA = ("a classifier has not yet been assigned "
"(try using `Evaluation.set_classifier (clf)`)")
"(try using `Evaluation.set_classifier(clf)`)")
ERROR_IKV = "`k_fold` argument must be an integer greater than or equal to 2"
ERROR_IKT = "`k_fold` argument must be an integer"
ERROR_INGV = "`n_grams` argument must be a positive integer"
Expand Down
24 changes: 13 additions & 11 deletions tests/test_pyss3.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def perform_tests_with(clf, cv_test, stopwords=True):
y_pred = clf.predict([doc_unknown], def_cat=STR_UNKNOWN)
assert y_pred[0] == STR_UNKNOWN_CATEGORY
y_pred = clf.predict([doc_unknown], multilabel=True)
assert y_pred[0] == [most_prob_cat]
y_pred = clf.predict([doc_unknown], def_cat=STR_UNKNOWN, multilabel=True, labels=False)
assert y_pred[0] == [IDX_UNKNOWN_CATEGORY]
assert y_pred[0] == []
y_pred = clf.predict([doc_unknown], def_cat=STR_MOST_PROBABLE, multilabel=True, labels=False)
assert y_pred[0] == [most_prob_cat_idx]

y_pred = clf.predict([doc_unknown], def_cat=STR_MOST_PROBABLE)
assert y_pred[0] == most_prob_cat
Expand Down Expand Up @@ -175,16 +175,17 @@ def perform_tests_with(clf, cv_test, stopwords=True):
assert len(multilabel_labels) == len(r)
assert r[0] in multilabel_idxs and r[1] in multilabel_idxs

assert clf.classify_multilabel('') == [most_prob_cat]
assert clf.classify_multilabel('', def_cat=STR_UNKNOWN) == [pyss3.STR_UNKNOWN_CATEGORY]
assert clf.classify_multilabel('') == []
assert clf.classify_multilabel('', def_cat=STR_MOST_PROBABLE) == [most_prob_cat]
assert clf.classify_multilabel('', def_cat=def_cat) == [def_cat]

assert clf.classify_multilabel(doc_unknown) == [most_prob_cat]
assert clf.classify_multilabel(doc_unknown, def_cat=STR_UNKNOWN) == [pyss3.STR_UNKNOWN_CATEGORY]
assert clf.classify_multilabel(doc_unknown) == []
assert clf.classify_multilabel(doc_unknown, def_cat=STR_MOST_PROBABLE) == [most_prob_cat]
assert clf.classify_multilabel(doc_unknown, def_cat=def_cat) == [def_cat]

assert clf.classify_multilabel(doc_unknown, labels=False) == [most_prob_cat_idx]
assert clf.classify_multilabel(doc_unknown, def_cat=STR_UNKNOWN, labels=False) == [-1]
assert clf.classify_multilabel(doc_unknown, labels=False) == []
assert clf.classify_multilabel(doc_unknown, def_cat=STR_MOST_PROBABLE,
labels=False) == [most_prob_cat_idx]
assert clf.classify_multilabel(doc_unknown, def_cat=def_cat, labels=False) == [def_cat_idx]

# "learn an doc_unknown and a new category" case
Expand Down Expand Up @@ -249,7 +250,8 @@ def test_pyss3_functions():
r = [(6, 8.1), (7, 5.6), (2, 5.5), (4, 1.5),
(5, 1.3), (3, 1.2), (0, 1.1), (1, 0.4)]
assert pyss3.kmean_multilabel_size(r) == 3
assert pyss3.kmean_multilabel_size([(0, 0), (1, 0)]) == 2
assert pyss3.kmean_multilabel_size([(0, 0), (1, 0)]) == 0
assert pyss3.kmean_multilabel_size([(0, 10), (1, 10)]) == 0

with pytest.raises(IndexError):
pyss3.mad([], 0)
Expand All @@ -272,7 +274,7 @@ def test_multilabel():
clf.fit(x_train, y_train)

assert sorted(clf.get_categories()) == ['insult', 'obscene', 'severe_toxic', 'toxic']
assert clf.classify_multilabel("this is a unknown document!") == ['toxic']
assert clf.classify_multilabel("this is a unknown document!") == []

y_pred = [[], ['toxic'], ['severe_toxic'], ['obscene'], ['insult'], ['toxic', 'insult']]

Expand Down

0 comments on commit 8b2ea60

Please sign in to comment.