Add extract_insight method to SS3 class
Given a document, this new method returns the pieces of text that
were involved in the classification decision, along with the
confidence values associated with them. Read the documentation for
more info.

This method was created to give the user a different way of
checking what SS3 is actually learning. Before this method, the
user had to use the "Live Test" tool to do it. Furthermore, this
method could be used to summarize the input document, leaving only
the relevant parts for whoever needs them.
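
As a rough sketch of the summarization idea mentioned in the commit message: assuming `extract_insight` returns a list of `(text, confidence value)` pairs, as its docstring states, a summary could be built by keeping only the strongest blocks. The `summarize` helper and the sample pairs are hypothetical, made up for illustration; they are not part of PySS3 and not real classifier output.

```python
def summarize(insights, top_n=2):
    """Return the text of the top_n blocks with the highest confidence value."""
    # Sort explicitly so the helper also works on output produced with sort=False.
    ranked = sorted(insights, key=lambda pair: pair[1], reverse=True)
    return [text for text, _ in ranked[:top_n]]


# Made-up (text, confidence value) pairs standing in for something like
# clf.extract_insight(doc, level='sentence'), where clf is a trained SS3 model.
sample_insights = [
    ("the match ended in a draw", 0.5),
    ("prices rose slightly", 0.25),
    ("the goalkeeper saved a penalty", 0.75),
]

summary = summarize(sample_insights)
print(summary)
```

Ranking inside the helper, rather than relying on the order of the input, keeps it usable regardless of the `sort` argument passed to `extract_insight`.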
sergioburdisso committed Feb 9, 2020
1 parent 911f3f6 commit eee1e29
Showing 1 changed file with 111 additions and 0 deletions.
111 changes: 111 additions & 0 deletions pyss3/__init__.py
@@ -1593,6 +1593,117 @@ def plot_value_distribution(self, cat):

        plt.show()

    def extract_insight(
        self, doc, cat='auto', level='word', window_size=3, min_cv=0.01, sort=True
    ):
        """
        Get the list of text blocks involved in the classification decision.

        Given a document, return the pieces of text that were involved in the
        classification decision, along with the confidence values associated
        with them. If a category is given, perform the process as if the
        given category were the one assigned by the classifier.

        :param doc: the content of the document
        :type doc: str
        :param cat: the category in relation to which text blocks are obtained.
                    If not present, it will automatically use the category assigned
                    by SS3 after classification.
                    Options are 'auto' or a given category name. (default: 'auto')
        :type cat: str
        :param level: the level at which text blocks are going to be extracted.
                      Options are 'word', 'sentence' or 'paragraph'. (default: 'word')
        :type level: str
        :param window_size: the number of words, before and after each identified word,
                            to be also included along with the identified word. For instance,
                            ``window_size=0`` means return only individual words,
                            ``window_size=1`` means also include the word that was
                            before and the one that was after them. If multiple selected
                            words are close enough for their word windows to be overlapping,
                            then those word windows will be merged into a longer and single one.
                            This argument is ignored when ``level`` is not equal to 'word'.
                            (default: 3)
        :type window_size: int
        :param min_cv: the minimum confidence value each text block must have to be
                       included in the output. (default: 0.01)
        :type min_cv: float
        :param sort: whether to return the text blocks ordered by their confidence value
                     or not. If ``sort=False`` then blocks will be returned
                     following the order they had in the input document. (default: True)
        :type sort: bool
        :returns: a list of pairs (text, confidence value) containing the text (blocks) involved,
                  and to what degree (*), in the classification decision.
                  (*) given by the confidence value
        :rtype: list
        :raises: InvalidCategoryError, ValueError
        """
        r = self.classify(doc, json=True)

        if cat == 'auto':
            # use the category assigned by the classifier
            c_i = r["cvns"][0][0]
        else:
            c_i = self.get_category_index(cat)
            if c_i == IDX_UNKNOWN_CATEGORY:
                Print.error(
                    "The expected values for the `cat` argument are 'auto' "
                    "or a valid category name, found '%s' instead" % str(cat),
                    raises=InvalidCategoryError
                )

        if level == 'paragraph':
            insights = [
                (
                    " ".join([word["lexeme"]
                              for s in p["sents"]
                              for word in s["words"]]),
                    p["cv"][c_i]
                )
                for p in r["pars"]
                if p["cv"][c_i] > min_cv
            ]
        elif level == 'sentence':
            insights = [
                (
                    " ".join([word["lexeme"]
                              for word in s["words"]]),
                    s["cv"][c_i]
                )
                for p in r["pars"] for s in p["sents"]
                if s["cv"][c_i] > min_cv
            ]
        elif level == 'word':
            insights = []
            for p in r["pars"]:
                for s in p["sents"]:
                    words = s["words"]
                    w_i = 0
                    while w_i < len(words):
                        w = words[w_i]
                        if w["cv"][c_i] > min_cv:
                            ww = []
                            ww_cv = 0
                            # step back to the start of the window (clamped to 0)
                            w_i = max(w_i - window_size, 0)
                            ww_i_end = min(w_i + window_size, len(words) - 1)
                            while w_i <= ww_i_end:
                                ww.append(words[w_i]["lexeme"])
                                ww_cv += words[w_i]["cv"][c_i]
                                # another relevant word: extend the window past it
                                if words[w_i]["cv"][c_i] > min_cv:
                                    ww_i_end = min(w_i + window_size, len(words) - 1)
                                w_i += 1
                            insights.append((" ".join(ww), ww_cv))
                        else:
                            w_i += 1
        else:
            raise ValueError(
                "expected values for the `level` argument are "
                "'word', 'sentence', or 'paragraph', found '%s' instead."
                % str(level)
            )

        if sort:
            # highest confidence value first
            insights.sort(key=lambda b_cv: -b_cv[1])
        return insights

    def learn(self, doc, cat, n_grams=1, prep=True, update=True):
        """
        Learn a new document for a given category.
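
The `window_size` behavior described in the docstring can be illustrated outside the class. The following is a simplified, stand-alone sketch, not the loop from this commit: it works on plain lists of words and confidence values instead of the classifier's JSON output, and it may differ from the committed code in edge cases. The `word_windows` name and the sample values are assumptions made for illustration.

```python
def word_windows(words, cvs, window_size=3, min_cv=0.01):
    """Group each relevant word with its neighbours, merging overlapping windows.

    words: list of str; cvs: list of float (one confidence value per word).
    Returns a list of (text, summed confidence value) pairs in document order.
    """
    spans = []  # list of [start, end] index pairs, both inclusive
    for i, cv in enumerate(cvs):
        if cv <= min_cv:
            continue  # not a relevant word
        start = max(i - window_size, 0)
        end = min(i + window_size, len(words) - 1)
        if spans and start <= spans[-1][1]:
            # This window overlaps the previous one: extend it instead.
            spans[-1][1] = end
        else:
            spans.append([start, end])
    return [
        (" ".join(words[s:e + 1]), sum(cvs[s:e + 1]))
        for s, e in spans
    ]


# Made-up sentence and confidence values, chosen only to show the grouping.
words = "the quick brown fox jumps over the lazy dog".split()
cvs = [0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.25]

print(word_windows(words, cvs, window_size=0))  # [('fox', 0.5), ('dog', 0.25)]
print(word_windows(words, cvs, window_size=1))  # [('brown fox jumps', 0.5), ('lazy dog', 0.25)]
print(word_windows(words, cvs, window_size=3))  # one merged block covering the sentence
```

With `window_size=3` the windows around "fox" and "dog" overlap, so they collapse into a single block whose value is the sum of both confidence values.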