Skip to content

Commit

Permalink
Improve predict speed (10x-200x faster!)
Browse files Browse the repository at this point in the history
The `predict` method (and therefore all methods using it, such as
`grid_search`) now runs from 10 to (above) 200 times faster (e.g. from 2
mins to 9s, or even 500ms when using the same `x_test` a second time,
which is the usual case in `grid_search(es)`).

Now, the only bottleneck left to speed up the overall `grid_search` time
is training time, which is already on the TODO list (i.e. rewrite
the training algorithm to use numpy and sparse matrices)
  • Loading branch information
sergioburdisso committed Mar 2, 2020
1 parent b4f8827 commit 37202d8
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 0 deletions.
147 changes: 147 additions & 0 deletions pyss3/__init__.py
Expand Up @@ -9,6 +9,7 @@
import re
import json
import errno
import numpy as np

from io import open
from time import time
Expand Down Expand Up @@ -97,6 +98,10 @@ class SS3:
__s_update__ = None
__p_update__ = None

__cv_cache__ = None
__last_x_test__ = None
__last_x_test_idx__ = None

__prun_floor__ = 10
__prun_trigger__ = 1000000
__prun_counter__ = 0
Expand Down Expand Up @@ -164,6 +169,10 @@ def __init__(
self.__cv_mode__ = cv_m
self.__sn_mode__ = sn_m

self.original_sumop_ngrams = self.summary_op_ngrams
self.original_sumop_sentences = self.summary_op_sentences
self.original_sumop_paragraphs = self.summary_op_paragraphs

def __lv__(self, ngram, icat, cache=True):
"""Local value function."""
if cache:
Expand Down Expand Up @@ -337,6 +346,12 @@ def __apply_fn__(self, fn, ngram, cat):
if w]
return fn(ngram, icat) if IDX_UNKNOWN_WORD not in ngram else 0

def __summary_ops_are_pristine__(self):
    """Tell whether all summary operators still hold their original values."""
    originals = (
        self.original_sumop_ngrams,
        self.original_sumop_sentences,
        self.original_sumop_paragraphs,
    )
    current = (
        self.summary_op_ngrams,
        self.summary_op_sentences,
        self.summary_op_paragraphs,
    )
    return originals == current

def __classify_ngram__(self, ngram):
"""Classify the given n-gram."""
cv = [
Expand Down Expand Up @@ -743,6 +758,105 @@ def __save_cat_vocab__(self, icat, path, n_grams):
f.close()
Print.info("\t[ %s stored in '%s'" % (term, voc_path))

def __update_cv_cache__(self):
    """(Re)compute the cached numpy ndarray of confidence values."""
    if self.__cv_cache__ is None:
        # one row per known term, one column per category
        shape = (len(self.__index_to_word__), len(self.__categories__))
        self.__cv_cache__ = np.zeros(shape)
    confidence = self.__cv__
    for term_idx, row in enumerate(self.__cv_cache__):
        for cat_idx in range(len(row)):
            try:
                row[cat_idx] = confidence([term_idx], cat_idx)
            except KeyError:
                # term unknown for this category -> no confidence
                row[cat_idx] = 0

def __predict_fast__(
    self, x_test, def_cat=STR_MOST_PROBABLE, labels=True,
    multilabel=False, proba=False, prep=True, leave_pbar=True
):
    """
    A faster version of the `predict` method (using numpy).

    It classifies each document by summing the (cached) per-term
    confidence vectors of its known words, instead of walking the
    model per document. Only valid when the model works with unigrams
    and the summary operators are unmodified (callers check this).

    :param x_test: the list of documents to classify
    :param def_cat: default category for documents with no positive
                    confidence (name, STR_UNKNOWN, or STR_MOST_PROBABLE)
    :param labels: if True, return category names instead of indexes
    :param multilabel: if True, return a list of categories per document
    :param proba: if True, return the raw confidence vector per document
    :param prep: enable input preprocessing
    :param leave_pbar: keep the progress bar after finishing
    :returns: the list of predictions (one entry per document)
    :raises: InvalidCategoryError
    """
    # Resolve `def_cat` to a category index (or IDX_UNKNOWN_CATEGORY).
    if not def_cat or def_cat == STR_UNKNOWN:
        def_cat = IDX_UNKNOWN_CATEGORY
    elif def_cat == STR_MOST_PROBABLE:
        def_cat = self.__get_most_probable_category__()
    else:
        def_cat = self.get_category_index(def_cat)
        if def_cat == IDX_UNKNOWN_CATEGORY:
            raise InvalidCategoryError

    # Make sure cached confidence values match current hyperparameters.
    if self.__update_needed__():
        self.update_values()

    if self.__cv_cache__ is None:
        self.__update_cv_cache__()
        self.__last_x_test__ = None  # could have learned a new word (in `learn`)
    cv_cache = self.__cv_cache__

    # Convert documents to lists of word indexes only once: the result is
    # memoized keyed by the MD5 digest of `x_test`, so repeated calls with
    # the same test set (the usual case in `grid_search`) skip this step.
    x_test_hash = hash(x_test)
    if x_test_hash == self.__last_x_test__:
        x_test_idx = self.__last_x_test_idx__
    else:
        self.__last_x_test__ = x_test_hash
        self.__last_x_test_idx__ = [None] * len(x_test)
        x_test_idx = self.__last_x_test_idx__
        word_index = self.get_word_index
        for doc_idx, doc in enumerate(tqdm(x_test, desc="Caching documents",
                                           leave=False, disable=Print.is_quiet())):
            # keep only words known to the model (unknown ones add nothing)
            x_test_idx[doc_idx] = [
                word_index(w)
                for w
                in re.split(self.__word_delimiter__, Pp.clean_and_ready(doc) if prep else doc)
                if word_index(w) != IDX_UNKNOWN_WORD
            ]

    y_pred = [None] * len(x_test)
    for doc_idx, doc in enumerate(tqdm(x_test_idx, desc="Classification",
                                       leave=leave_pbar, disable=Print.is_quiet())):
        if self.__a__ > 0:
            # fancy indexing returns a copy, so zeroing values not above
            # the `a` threshold does not corrupt the cache
            doc_cvs = cv_cache[doc]
            doc_cvs[doc_cvs <= self.__a__] = 0
            pred_cv = np.add.reduce(doc_cvs, 0)
        else:
            pred_cv = np.add.reduce(cv_cache[doc], 0)

        if proba:
            # raw confidence vector, no category selection
            y_pred[doc_idx] = list(pred_cv)
            continue

        if not multilabel:
            # single-label: argmax, or the default category if the
            # document produced no confidence at all
            if pred_cv.sum() == 0:
                y_pred[doc_idx] = def_cat
            else:
                y_pred[doc_idx] = np.argmax(pred_cv)

            if labels:
                if y_pred[doc_idx] != IDX_UNKNOWN_CATEGORY:
                    y_pred[doc_idx] = self.__categories__[y_pred[doc_idx]][NAME]
                else:
                    y_pred[doc_idx] = STR_UNKNOWN_CATEGORY
        else:
            # multi-label: a (possibly singleton) list per document
            if pred_cv.sum() == 0:
                if def_cat == IDX_UNKNOWN_CATEGORY:
                    y_pred[doc_idx] = [STR_UNKNOWN_CATEGORY if labels else def_cat]
                else:
                    y_pred[doc_idx] = [self.get_category_name(def_cat) if labels else def_cat]
            else:
                # categories sorted by decreasing confidence;
                # `kmean_multilabel_size` decides how many to keep
                r = sorted(
                    [
                        (i, pred_cv[i])
                        for i in range(pred_cv.size)
                    ],
                    key=lambda e: -e[1]
                )
                if labels:
                    y_pred[doc_idx] = [
                        self.get_category_name(cat_i)
                        for cat_i, _ in r[:kmean_multilabel_size(r)]
                    ]
                else:
                    y_pred[doc_idx] = [cat_i for cat_i, _ in r[:kmean_multilabel_size(r)]]
    return y_pred

def summary_op_ngrams(self, cvs):
"""
Summary operator for n-gram confidence vectors.
Expand Down Expand Up @@ -1438,6 +1552,9 @@ def update_values(self, force=False):
self.__l_update__ = self.__l__
self.__p_update__ = self.__p__

if self.__cv_cache__ is not None:
self.__update_cv_cache__()

def print_model_info(self):
"""Print information regarding the model."""
print()
Expand Down Expand Up @@ -1779,6 +1896,8 @@ def learn(self, doc, cat, n_grams=1, prep=True, update=True):
:param update: enables model auto-update after learning (default: True)
:type update: bool
"""
self.__cv_cache__ = None

try:
doc = doc.decode(ENCODING)
except UnicodeEncodeError: # for python 2 compatibility
Expand Down Expand Up @@ -2103,6 +2222,10 @@ def predict_proba(self, x_test, prep=True, leave_pbar=True):
if not self.__categories__:
raise EmptyModelError

if self.get_ngrams_length() == 1 and self.__summary_ops_are_pristine__():
return self.__predict_fast__(x_test, prep=prep,
leave_pbar=leave_pbar, proba=True)

x_test = list(x_test)
classify = self.classify
return [
Expand Down Expand Up @@ -2160,6 +2283,11 @@ def predict(
if self.get_category_index(def_cat) == IDX_UNKNOWN_CATEGORY:
raise InvalidCategoryError

if self.get_ngrams_length() == 1 and self.__summary_ops_are_pristine__():
return self.__predict_fast__(x_test, def_cat=def_cat, labels=labels,
multilabel=multilabel, prep=prep,
leave_pbar=leave_pbar)

stime = time()
Print.info("about to start classifying test documents", offset=1)
classify = self.classify_label if not multilabel else self.classify_multilabel
Expand Down Expand Up @@ -2390,6 +2518,25 @@ def re_split_keep(regex, string):
return re.split(regex, string)


def hash(str_list):
    """
    Return a hash value for a given list of strings.

    Note: this deliberately shadows the `hash` builtin (name kept for
    backward compatibility with existing callers).

    :param str_list: a list of strings (e.g. x_test)
    :type str_list: list (of str)
    :returns: an MD5 hash value (hex digest)
    :rtype: str
    """
    import hashlib
    m = hashlib.md5()
    for doc in str_list:
        try:
            # Python 2 `str` (bytes) can be fed to md5 directly
            m.update(doc)
        except (TypeError, UnicodeEncodeError):
            # Unicode strings must be encoded first. Use UTF-8 (lossless)
            # instead of 'ascii' + 'ignore': the latter silently dropped
            # non-ASCII characters, so two documents differing only in
            # those characters collided to the same digest — producing
            # false cache hits in `__predict_fast__`.
            m.update(doc.encode('utf-8'))
    return m.hexdigest()


def vsum(v0, v1):
    """
    Vectorial version of sum.

    :param v0: the first vector
    :param v1: the second vector (must be at least as long as `v0`)
    :returns: the element-wise sum, with the length of `v0`
    :rtype: list
    """
    # `range` instead of `xrange`: `xrange` does not exist on Python 3
    # (NameError); `range` behaves identically here on both versions.
    return [v0[i] + v1[i] for i in range(len(v0))]
Expand Down
12 changes: 12 additions & 0 deletions tests/test_pyss3.py
Expand Up @@ -97,13 +97,25 @@ def perform_tests_with(clf, cv_test, stopwords=True):
# predict
y_pred = clf.predict(x_test)
assert y_pred == y_test
clf.set_a(.1)
y_pred = clf.predict(x_test)
assert y_pred == y_test
clf.set_a(0)
y_pred = clf.predict(x_test, multilabel=True)
assert y_pred == [[y] for y in y_test]
y_pred = clf.predict(x_test, multilabel=True, labels=False)
assert y_pred == [[clf.get_category_index(y)] for y in y_test]

y_pred = clf.predict(x_test, labels=False)
y_pred = [clf.get_category_name(ic) for ic in y_pred]
assert y_pred == y_test

y_pred = clf.predict([doc_unknown], def_cat=STR_UNKNOWN)
assert y_pred[0] == STR_UNKNOWN_CATEGORY
y_pred = clf.predict([doc_unknown], multilabel=True)
assert y_pred[0] == [most_prob_cat]
y_pred = clf.predict([doc_unknown], def_cat=STR_UNKNOWN, multilabel=True, labels=False)
assert y_pred[0] == [IDX_UNKNOWN_CATEGORY]

y_pred = clf.predict([doc_unknown], def_cat=STR_MOST_PROBABLE)
assert y_pred[0] == most_prob_cat
Expand Down

0 comments on commit 37202d8

Please sign in to comment.