Skip to content

Commit

Permalink
Merge pull request #12 from angrymeir/master
Browse files Browse the repository at this point in the history
Rename SS3 initialization parameter from sn_m to sg_m
  • Loading branch information
sergioburdisso committed Jun 30, 2020
2 parents bc8ec0a + c2bddcb commit 63b80d9
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
20 changes: 10 additions & 10 deletions pyss3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ class SS3:
term (word or n-grams), options are:
"norm_gv_xai", "norm_gv" and "gv" (default: "norm_gv_xai")
:type cv_m: str
:param sn_m: method used to compute the sanction (sn) function, options
:param sg_m: method used to compute the significance (sg) function, options
are: "vanilla" and "xai" (default: "xai")
:type sn_m: str
:type sg_m: str
"""

__name__ = "model"
Expand Down Expand Up @@ -119,7 +119,7 @@ class SS3:

def __init__(
self, s=None, l=None, p=None, a=None,
name="", cv_m=STR_NORM_GV_XAI, sn_m=STR_XAI
name="", cv_m=STR_NORM_GV_XAI, sg_m=STR_XAI
):
"""
Class constructor.
Expand All @@ -140,9 +140,9 @@ def __init__(
term (word or n-grams), options are:
"norm_gv_xai", "norm_gv" and "gv" (default: "norm_gv_xai")
:type cv_m: str
:param sn_m: method used to compute the sanction (sn) function, options
:param sg_m: method used to compute the significance (sg) function, options
are: "vanilla" and "xai" (default: "xai")
:type sn_m: str
:type sg_m: str
:raises: ValueError
"""
self.__name__ = (name or self.__name__).lower()
Expand Down Expand Up @@ -172,13 +172,13 @@ def __init__(
elif cv_m == STR_GV:
self.__cv__ = self.__gv__

if sn_m == STR_XAI:
if sg_m == STR_XAI:
self.__sg__ = self.__sg_xai__
elif sn_m == STR_VANILLA:
elif sg_m == STR_VANILLA:
self.__sg__ = self.__sg_vanilla__

self.__cv_mode__ = cv_m
self.__sn_mode__ = sn_m
self.__sg_mode__ = sg_m

self.original_sumop_ngrams = self.summary_op_ngrams
self.original_sumop_sentences = self.summary_op_sentences
Expand Down Expand Up @@ -1400,7 +1400,7 @@ def save_model(self, path=None):
"__index_to_word__": self.__index_to_word__,
"__word_to_index__": self.__word_to_index__,
"__cv_mode__": self.__cv_mode__,
"__sn_mode__": self.__sn_mode__,
"__sg_mode__": self.__sg_mode__,
"__multilabel__": self.__multilabel__
}

Expand Down Expand Up @@ -1470,7 +1470,7 @@ def load_model(self, path=None):
self.__index_to_word__ = jmodel["__index_to_word__"]
self.__word_to_index__ = jmodel["__word_to_index__"]
self.__cv_mode__ = jmodel["__cv_mode__"]
self.__sn_mode__ = jmodel["__sn_mode__"]
self.__sg_mode__ = jmodel["__sg_mode__"]
self.__multilabel__ = jmodel["__multilabel__"] if "__multilabel__" in jmodel else False

self.__zero_cv__ = (0,) * len(self.__categories__)
Expand Down
18 changes: 9 additions & 9 deletions tests/test_pyss3.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def test_pyss3_ss3(mockers):
# training different cases
clf = SS3(
s=.45, l=.5, p=1, a=0,
cv_m=STR_NORM_GV_XAI, sn_m=STR_XAI
cv_m=STR_NORM_GV_XAI, sg_m=STR_XAI
)
clf.set_name("test")

Expand All @@ -333,7 +333,7 @@ def test_pyss3_ss3(mockers):
clf.train([], [])

# train and predict/classify tests (model: terms are single words)
# cv_m=STR_NORM_GV_XAI, sn_m=STR_XAI
# cv_m=STR_NORM_GV_XAI, sg_m=STR_XAI
clf.fit(x_train, y_train)

assert clf.get_ngrams_length() == 1
Expand All @@ -347,30 +347,30 @@ def test_pyss3_ss3(mockers):
perform_tests_on(clf.cv, 0, "video games", "science&technology")
perform_tests_on(clf.gv, 0, "video games", "science&technology")

# cv_m=STR_NORM_GV, sn_m=STR_XAI
# cv_m=STR_NORM_GV, sg_m=STR_XAI
clf = SS3(
s=.45, l=.5, p=1, a=0, name="test-norm-gv-sn-xai",
cv_m=STR_NORM_GV, sn_m=STR_XAI
cv_m=STR_NORM_GV, sg_m=STR_XAI
)
clf.fit(x_train, y_train)

perform_tests_with(clf, [.00114, .00294, 0, 0, 0, .00016, .01878, 8.43969])
perform_tests_on(clf.cv, 0.4307)

# cv_m=STR_GV, sn_m=STR_XAI
# cv_m=STR_GV, sg_m=STR_XAI
clf = SS3(
s=.45, l=.5, p=1, a=0, name="test-gv-sn-xai",
cv_m=STR_GV, sn_m=STR_XAI
cv_m=STR_GV, sg_m=STR_XAI
)
clf.fit(x_train, y_train)

perform_tests_with(clf, [.00062, .00109, 0, 0, 0, .00014, .01878, 6.31605])
assert clf.cv("chicken", "food") == clf.gv("chicken", "food")

# cv_m=STR_NORM_GV_XAI, sn_m=STR_VANILLA
# cv_m=STR_NORM_GV_XAI, sg_m=STR_VANILLA
clf = SS3(
s=.45, l=.5, p=1, a=0, name="test-norm-gv-xai-sn-vanilla",
cv_m=STR_NORM_GV_XAI, sn_m=STR_VANILLA
cv_m=STR_NORM_GV_XAI, sg_m=STR_VANILLA
)
clf.fit(x_train, y_train)

Expand All @@ -379,7 +379,7 @@ def test_pyss3_ss3(mockers):
# train and predict/classify tests (model: terms are word n-grams)
clf = SS3(
name="test-3grams",
cv_m=STR_NORM_GV_XAI, sn_m=STR_XAI
cv_m=STR_NORM_GV_XAI, sg_m=STR_XAI
)

clf.fit(x_train, y_train, n_grams=3)
Expand Down

0 comments on commit 63b80d9

Please sign in to comment.