Skip to content

Commit

Permalink
Add custom preprocessing to Live Test Tool (#3)
Browse files Browse the repository at this point in the history
Resolves: #3
  • Loading branch information
sergioburdisso committed May 5, 2020
1 parent 26fff88 commit b50cfaf
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 37 deletions.
2 changes: 1 addition & 1 deletion examples/extract_insight.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"we can see that, unlike the previous ones, these fragments focus less on health-related aspects and much more on science/scientific ones, SS3 even gave us the Method, Objective and Conclusion well-known sections of research papers. For instance, if we read the first fragment without any context, \"Method: This study used a parallel randomized control group design to compare pre-test and post\", we as humans, can clearly see it is related to science."
"we can see that, unlike the previous ones, these fragments focus less on health-related aspects and much more on science/scientific ones, SS3 even gave us the Method and Objective well-known sections of research papers. For instance, if we read the first fragment without any context, \"Method: This study used a parallel randomized control group design to compare pre-test and post\", we as humans, can clearly see it is related to science."
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions examples/topic_categorization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@
"metadata": {},
"outputs": [],
"source": [
"clf.set_hyperparameters(s=0.32, l=1.62, p=2.55)"
"clf.set_hyperparameters(s=0.32, l=1.62, p=2.35)"
]
},
{
Expand Down Expand Up @@ -421,7 +421,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"And that's how we found out that these hyperparameter values (``s=0.32, l=1.62, p=2.55``) were going to improve our classifier accuracy."
"And that's how we found out that these hyperparameter values (``s=0.32, l=1.62, p=2.35``) were going to improve our classifier accuracy."
]
},
{
Expand Down
72 changes: 42 additions & 30 deletions pyss3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def __classify_ngram__(self, ngram):
cv[:] = [(v if v > self.__a__ else 0) for v in cv]
return cv

def __classify_sentence__(self, sent, prep, json=False):
def __classify_sentence__(self, sent, prep, json=False, prep_func=None):
"""Classify the given sentence."""
classify_trans = self.__classify_ngram__
cats = xrange(len(self.__categories__))
Expand All @@ -374,7 +374,8 @@ def __classify_sentence__(self, sent, prep, json=False):
if not json:
regex = "%s|[^%s]+" % (word_regex, word_delimiter)
if prep:
sent = Pp.clean_and_ready(sent)
prep_func = prep_func or Pp.clean_and_ready
sent = prep_func(sent)
sent_words = [
(w, w)
for w in re.findall(regex, sent)
Expand All @@ -383,7 +384,7 @@ def __classify_sentence__(self, sent, prep, json=False):
else:
if prep:
sent_words = [
(w, Pp.clean_and_ready(w, dots=False))
(w, Pp.clean_and_ready(w, dots=False) if prep_func is None else prep_func(w))
for w in re_split_keep(word_regex, sent)
if w
]
Expand Down Expand Up @@ -499,11 +500,11 @@ def __classify_sentence__(self, sent, prep, json=False):
"wmv": reduce(vmax, [v["cv"] for v in info]) # word max value
}

def __classify_paragraph__(self, parag, prep, json=False):
def __classify_paragraph__(self, parag, prep, json=False, prep_func=None):
"""Classify the given paragraph."""
if not json:
sents_cvs = [
self.__classify_sentence__(sent, prep=prep)
self.__classify_sentence__(sent, prep=prep, prep_func=prep_func)
for sent in re.split(self.__sent_delimiter__, parag)
if sent
]
Expand All @@ -512,7 +513,7 @@ def __classify_paragraph__(self, parag, prep, json=False):
return self.__zero_cv__
else:
info = [
self.__classify_sentence__(sent, prep=prep, json=True)
self.__classify_sentence__(sent, prep=prep, prep_func=prep_func, json=True)
for sent in re_split_keep(self.__sent_delimiter__, parag)
if sent
]
Expand Down Expand Up @@ -1807,6 +1808,7 @@ def extract_insight(
:raises: InvalidCategoryError, ValueError
"""
r = self.classify(doc, json=True)
word_regex = self.__word_regex__

if cat == 'auto':
c_i = r["cvns"][0][0]
Expand Down Expand Up @@ -1861,7 +1863,7 @@ def extract_insight(
if words[w_i]["cv"][c_i] > min_cv:
ww_left += min(ww_size, (len(words) - 1) - w_i)

if re.search(r"[\w\d]+", words[w_i]["lexeme"]):
if re.search(word_regex, words[w_i]["lexeme"]):
ww_left -= 1

w_i += 1
Expand Down Expand Up @@ -1991,7 +1993,7 @@ def learn(self, doc, cat, n_grams=1, prep=True, update=True):
if update:
self.update_values(force=True)

def classify(self, doc, prep=True, sort=True, json=False):
def classify(self, doc, prep=True, sort=True, json=False, prep_func=None):
"""
Classify a given document.
Expand All @@ -2003,6 +2005,11 @@ def classify(self, doc, prep=True, sort=True, json=False):
:type sort: bool
:param json: return a debugging version of the result in JSON format.
:type json: bool
:param prep_func: the custom preprocessing function to be applied to
the given document before classifying it (only when ``prep=True``).
If not given, the default preprocessing function will
be used instead (again, as long as ``prep=True``).
:type prep_func: function
:returns: the document confidence vector if ``sort`` is False.
If ``sort`` is True, a list of pairs
(category index, confidence value) ordered by confidence value.
Expand All @@ -2025,7 +2032,7 @@ def classify(self, doc, prep=True, sort=True, json=False):

if not json:
paragraphs_cvs = [
self.__classify_paragraph__(parag, prep=prep)
self.__classify_paragraph__(parag, prep=prep, prep_func=prep_func)
for parag in re.split(self.__parag_delimiter__, doc)
if parag
]
Expand All @@ -2044,7 +2051,7 @@ def classify(self, doc, prep=True, sort=True, json=False):
return cv
else:
info = [
self.__classify_paragraph__(parag, prep=prep, json=True)
self.__classify_paragraph__(parag, prep=prep, prep_func=prep_func, json=True)
for parag in re_split_keep(self.__parag_delimiter__, doc)
if parag
]
Expand Down Expand Up @@ -2316,10 +2323,11 @@ def cv(self, ngram, cat):
[*] the gv function is defined in Section 3.2.2 of the original paper:
https://arxiv.org/pdf/1905.08772.pdf
Example
>>> clf.cv("chicken", "food")
>>> clf.cv("roast chicken", "food")
>>> clf.cv("chicken", "sports")
Examples:
>>> clf.cv("chicken", "food")
>>> clf.cv("roast chicken", "food")
>>> clf.cv("chicken", "sports")
:param ngram: the word or word n-gram
:type ngram: str
Expand All @@ -2338,10 +2346,11 @@ def gv(self, ngram, cat):
(gv function is defined in Section 3.2.2 of the original paper:
https://arxiv.org/pdf/1905.08772.pdf)
Example
>>> clf.gv("chicken", "food")
>>> clf.gv("roast chicken", "food")
>>> clf.gv("chicken", "sports")
Examples:
>>> clf.gv("chicken", "food")
>>> clf.gv("roast chicken", "food")
>>> clf.gv("chicken", "sports")
:param ngram: the word or word n-gram
:type ngram: str
Expand All @@ -2360,10 +2369,11 @@ def lv(self, ngram, cat):
(lv function is defined in Section 3.2.2 of the original paper:
https://arxiv.org/pdf/1905.08772.pdf)
Example
>>> clf.lv("chicken", "food")
>>> clf.lv("roast chicken", "food")
>>> clf.lv("chicken", "sports")
Examples:
>>> clf.lv("chicken", "food")
>>> clf.lv("roast chicken", "food")
>>> clf.lv("chicken", "sports")
:param ngram: the word or word n-gram
:type ngram: str
Expand All @@ -2382,10 +2392,11 @@ def sg(self, ngram, cat):
(sg function is defined in Section 3.2.2 of the original paper:
https://arxiv.org/pdf/1905.08772.pdf)
Example
>>> clf.sg("chicken", "food")
>>> clf.sg("roast chicken", "food")
>>> clf.sg("chicken", "sports")
Examples:
>>> clf.sg("chicken", "food")
>>> clf.sg("roast chicken", "food")
>>> clf.sg("chicken", "sports")
:param ngram: the word or word n-gram
:type ngram: str
Expand All @@ -2404,10 +2415,11 @@ def sn(self, ngram, cat):
(sn function is defined in Section 3.2.2 of the original paper:
https://arxiv.org/pdf/1905.08772.pdf)
Example
>>> clf.sn("chicken", "food")
>>> clf.sn("roast chicken", "food")
>>> clf.sn("chicken", "sports")
Examples:
>>> clf.sn("chicken", "food")
>>> clf.sn("roast chicken", "food")
>>> clf.sn("chicken", "sports")
:param ngram: the word or word n-gram
:type ngram: str
Expand Down
18 changes: 14 additions & 4 deletions pyss3/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ class Server:
__x_test__ = None
__test_path__ = None
__folder_label__ = None
__preprocess__ = None

@staticmethod
def __send_as_json__(sock, data):
Expand Down Expand Up @@ -229,7 +230,10 @@ def __do_classify__(sock, doc):
Print.show("\t%s[...]" % doc[:50])
Server.__send_as_json__(
sock,
Server.__clf__.classify(doc, json=True)
Server.__clf__.classify(
doc,
prep_func=Server.__preprocess__,
json=True)
)
Print.info("sending classification result...")
except Exception as e:
Expand Down Expand Up @@ -331,7 +335,7 @@ def __load_testset_from_files__():
with open(
file_path, "r", encoding=ENCODING
) as fdoc:
r = classify(fdoc.read())
r = classify(fdoc.read(), prep_func=Server.__preprocess__)
Server.__docs__[cat]["clf_result"].append(
r[0][0] if r[0][1] else unkwon_cat_i
)
Expand Down Expand Up @@ -387,7 +391,7 @@ def set_testset(x_test, y_test):
Server.__docs__[cat]["file"].append(doc_name)
Server.__docs__[cat]["path"].append(":x_test:%d" % idoc)

r = classify(doc)
r = classify(doc, prep_func=Server.__preprocess__)
Server.__docs__[cat]["clf_result"].append(
r[0][0] if r[0][1] else unkwon_cat_i
)
Expand Down Expand Up @@ -463,7 +467,7 @@ def start_listening(port=0):

@staticmethod
def serve(
clf=None, x_test=None, y_test=None, port=0, browser=True, quiet=True
clf=None, x_test=None, y_test=None, port=0, browser=True, quiet=True, prep_func=None
):
"""
Wait for classification requests and serve them.
Expand All @@ -482,8 +486,14 @@ def serve(
:param quiet: if True, use quiet mode. Otherwise use verbose mode
(default: True)
:type quiet: bool
:param prep_func: the custom preprocessing function to be applied to
each document (test-set documents and documents received
for classification) before classifying it.
If not given, the default preprocessing function will be used.
:type prep_func: function
"""
Server.__clf__ = clf or Server.__clf__
Server.__preprocess__ = prep_func

if not Server.__clf__:
Print.error("a model must be given before serving")
Expand Down

0 comments on commit b50cfaf

Please sign in to comment.