Add loading of multilabel datasets from disk in Live Test
A new function, ``set_testset_from_files_multilabel()``, was added to
``Live_Test``. It is the multi-label version of
``Live_Test.set_testset_from_files()``.

In addition, the "-phl"/"--path-labels" argument was added to the command-line
version of the Live Test server.
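A minimal usage sketch of the new API and the label-file layouts it expects (the dataset paths, model name, and category labels below are hypothetical, and the single-file layout assumes the default ``';'`` label separator). On the command line, the same behavior is enabled by passing the new ``-phl``/``--path-labels`` argument alongside ``-ph``/``--path``.

from pyss3 import SS3
from pyss3.server import Live_Test

clf = SS3(name="my_model")  # hypothetical model name
clf.load_model()            # assumes a previously trained and saved model
Live_Test.set_model(clf)

# If docs_path is a folder of .txt documents, labels_path holds one
# "document_name<sep_label>label" pair per line (names repeat, one line per label):
#     doc_0  sports
#     doc_0  politics
#     doc_1  economy
# If docs_path is a single file (one document per line), labels_path instead holds
# one list of labels per document, joined by sep_label (';' by default):
#     sports;politics
#     economy
Live_Test.set_testset_from_files_multilabel("datasets/topics/test_docs",
                                            "datasets/topics/labels.tsv")
Live_Test.serve(port=4000)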
sergioburdisso committed May 26, 2020
1 parent 6805230 commit 0ddbd6a
Showing 3 changed files with 123 additions and 64 deletions.
3 changes: 2 additions & 1 deletion pyss3/__init__.py
@@ -638,7 +638,8 @@ def __get_def_cat__(self, def_cat):
if def_cat is not None and (def_cat not in [STR_MOST_PROBABLE, STR_UNKNOWN] and
self.get_category_index(def_cat) == IDX_UNKNOWN_CATEGORY):
raise ValueError(
"the default category must be 'most-probable', 'unknown', or a category name."
"the default category must be 'most-probable', 'unknown', or a category name"
" (current value is '%s')." % str(def_cat)
)
def_cat = None if def_cat == STR_UNKNOWN else def_cat
return self.get_most_probable_category() if def_cat == STR_MOST_PROBABLE else def_cat
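For reference, a hedged sketch of the values ``def_cat`` accepts after this change; the check above is reached from the public API (for instance through ``Live_Test.set_testset()``, shown further down in this commit), and the test data and category names here are made up:

from pyss3.server import Live_Test  # assumes a trained model was already set via Live_Test.set_model(clf)

x_test = ["some hypothetical test document"]
y_test = ["sports"]
Live_Test.set_testset(x_test, y_test, def_cat="most-probable")  # OK: fall back to the most probable category
Live_Test.set_testset(x_test, y_test, def_cat="unknown")        # OK: leave undecidable documents as "unknown"
Live_Test.set_testset(x_test, y_test, def_cat="sports")         # OK, provided "sports" is a learned category
Live_Test.set_testset(x_test, y_test, def_cat="spotrs")         # ValueError, now reporting the offending value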
171 changes: 110 additions & 61 deletions pyss3/server.py
@@ -12,7 +12,8 @@
from datetime import datetime

from . import SS3, kmean_multilabel_size, __version__
from .util import is_a_collection, membership_matrix, RecursiveDefaultDict, Print, VERBOSITY
from .util import is_a_collection, membership_matrix, VERBOSITY
from .util import Dataset, Print, RecursiveDefaultDict

import numpy as np
import webbrowser
@@ -60,6 +61,10 @@ def main():

parser.add_argument('MODEL', help="the model name")
parser.add_argument('-ph', '--path', help="the test set path")
parser.add_argument(
'-phl', '--path-labels', default=None,
help="the path to the labels (only for multilabel)"
)
parser.add_argument(
'-l', '--label', choices=["file", "folder"], default="folder",
help="indicates where to read category labels from"
@@ -72,6 +77,7 @@
)
args = parser.parse_args()

Print.set_verbosity(VERBOSITY.VERBOSE)
if args.quiet:
Print.set_verbosity(VERBOSITY.QUIET)

Expand All @@ -90,7 +96,10 @@ def main():

Server.set_model(clf)
if args.path:
Server.set_testset_from_files(args.path, args.label == 'folder')
if args.path_labels is None:
Server.set_testset_from_files(args.path, args.label == 'folder')
else:
Server.set_testset_from_files_multilabel(args.path, args.path_labels)

try:
Server.serve(port=args.port, browser=False, quiet=args.quiet)
@@ -144,8 +153,11 @@ class Server:

__x_test__ = None
__test_path__ = None
__folder_label__ = None
__test_path_prev__ = None
__labels_path__ = None
__sep_doc__ = None
__sep_label__ = None
__folder_label__ = None
__preprocess__ = None
__default_prep__ = None
__default_cat__ = None
@@ -262,11 +274,7 @@ def __do_get_info__(sock):
def __do_get_doc__(sock, file):
"""Serve the 'get_doc' message."""
doc = ""
if ":line:" in file:
file, line_n = file.split(":line:")
with open(file, 'r', encoding=ENCODING) as fdoc:
doc = fdoc.read().split(Server.__sep_doc__)[int(line_n)]
elif ":x_test:" in file:
if ":x_test:" in file:
if Server.__x_test__ is not None:
idoc = int(file.split(":x_test:")[1])
doc = Server.__x_test__[idoc]
@@ -282,67 +290,69 @@ def __clear_testset__():
"""Clear server's test documents."""
Server.__docs__ = RecursiveDefaultDict()
Server.__test_path__ = None
Server.__test_path_prev__ = None
Server.__folder_label__ = None
Server.__x_test__ = None
Server.__labels_path__ = None
Server.__sep_doc__ = None
Server.__sep_label__ = None

@staticmethod
def __load_testset_from_files__():
"""Load the test set files to visualize from ``test_path``."""
Print.info("reading files...")
classify = Server.__clf__.classify
unkwon_cat_i = len(Server.__clf__.get_categories())
if not Server.__folder_label__:
for file in listdir(Server.__test_path__):
file_path = path.join(Server.__test_path__, file)
if path.isfile(file_path):
cat = path.splitext(file)[0]

with open(file_path, "r", encoding=ENCODING) as fcat:
Server.__docs__[cat]["clf_result"] = [
r[0][0] if r[0][1] else unkwon_cat_i
for r in map(
classify,
tqdm(fcat.readlines(),
desc=" Classifying '%s' docs" % cat,
disable=Print.is_quiet())
)
]
n_docs = len(Server.__docs__[cat]["clf_result"])
Server.__docs__[cat]["file"] = [
"doc_%d" % line
for line in range(n_docs)
]
Server.__docs__[cat]["path"] = [
"%s:line:%d" % (file_path, line)
for line in range(n_docs)
]
else:
for cat in listdir(Server.__test_path__):
cat_path = path.join(Server.__test_path__, cat)

if not path.isfile(cat_path):

Server.__docs__[cat]["path"] = []
Server.__docs__[cat]["file"] = []
Server.__docs__[cat]["clf_result"] = []

for file in tqdm(sorted(listdir(cat_path)),
desc=" Classifying '%s' docs" % cat,
disable=Print.is_quiet()):
file_path = path.join(cat_path, file)
if path.isfile(file_path):
Server.__docs__[cat]["path"].append(file_path)
Server.__docs__[cat]["file"].append(file)
with open(
file_path, "r", encoding=ENCODING
) as fdoc:
r = classify(fdoc.read(), prep_func=Server.__preprocess__)
Server.__docs__[cat]["clf_result"].append(
r[0][0] if r[0][1] else unkwon_cat_i
)

Print.info("%d categories found" % len(Server.__docs__))
if Server.__labels_path__:
docs_path = Server.__test_path__
labels_path = Server.__labels_path__
sep_label = Server.__sep_label__
x_test, y_test = Dataset.load_from_files_multilabel(docs_path, labels_path,
sep_label, Server.__sep_doc__)
Server.set_testset(x_test, y_test)
if path.isdir(docs_path):
docs = Server.__docs__['']
Server.__x_test__ = None
sep_label = sep_label or r'\s+' # default separator
with open(labels_path, "r", encoding=ENCODING) as flabels:
doc_raw_names = [re.split(sep_label, l.rstrip())[0]
for l in flabels.read().split('\n')]
doc_i = 0
for doc_name in doc_raw_names:
doc_name += ".txt"
if doc_i == 0 or docs["file"][doc_i - 1] != doc_name:
docs["file"][doc_i] = doc_name
docs["path"][doc_i] = path.join(docs_path, doc_name)
doc_i += 1
else:
classify = Server.__clf__.classify
unkwon_cat_i = len(Server.__clf__.get_categories())
if not Server.__folder_label__:
x_test, y_test = Dataset.load_from_files(Server.__test_path__, False,
sep_doc=Server.__sep_doc__)
Server.set_testset(x_test, y_test)
else:
for cat in listdir(Server.__test_path__):
cat_path = path.join(Server.__test_path__, cat)
if not path.isfile(cat_path):
Server.__docs__[cat]["path"] = []
Server.__docs__[cat]["file"] = []
Server.__docs__[cat]["clf_result"] = []
for file in tqdm(sorted(listdir(cat_path)),
desc=" Classifying '%s' docs" % cat,
disable=Print.is_quiet()):
file_path = path.join(cat_path, file)
if path.isfile(file_path):
Server.__docs__[cat]["path"].append(file_path)
Server.__docs__[cat]["file"].append(file)
with open(
file_path, "r", encoding=ENCODING
) as fdoc:
r = classify(fdoc.read(), prep_func=Server.__preprocess__)
Server.__docs__[cat]["clf_result"].append(
r[0][0] if r[0][1] else unkwon_cat_i
)

Print.info("%d categories found" % len(Server.__docs__))
return len(Server.__docs__) > 0
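As a side note, when ``docs_path`` is a folder the branch above collapses the label file's document names (one line per label, so names repeat consecutively) into one entry per document; a minimal sketch of that collapsing behavior with made-up names:

names = ["doc_0", "doc_0", "doc_1"]  # as read from labels_path, one line per (document, label) pair
files = []
for name in names:
    if not files or files[-1] != name + ".txt":
        files.append(name + ".txt")  # keep one file name per consecutive run
print(files)  # ['doc_0.txt', 'doc_1.txt']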

@staticmethod
@@ -430,6 +440,7 @@ def set_testset(x_test, y_test=None, def_cat=None):
for labels in y_pred]
t = membership_matrix(clf, y_test_labels).todense()
p = membership_matrix(clf, y_pred, labels=False).todense()
np.seterr(divide='ignore', invalid='ignore')
accuracy = (t & p).sum(axis=1) / (t | p).sum(axis=1)
accuracy[np.isnan(accuracy)] = 1
docs[y_test[0]]["true_labels"] = y_test_labels
@@ -456,6 +467,7 @@ def set_testset_from_files(test_path, folder_label=True, sep_doc='\n'):
:rtype: bool
"""
Server.__clear_testset__()
Server.__test_path_prev__ = Server.__test_path__
Server.__test_path__ = test_path
Server.__folder_label__ = folder_label
Server.__sep_doc__ = sep_doc
@@ -477,6 +489,43 @@ def set_testset_from_files(test_path, folder_label=True, sep_doc='\n'):

return docs > 0

@staticmethod
def set_testset_from_files_multilabel(docs_path, labels_path, sep_label=None, sep_doc='\n'):
r"""
Multilabel version of the ``Live_Test.set_testset_from_files()`` function.
Load test documents and category labels from disk to visualize in the Live Test tool.
:param docs_path: the file or the folder containing the test documents.
:type docs_path: str
:param labels_path: the file containing the labels for each document.
* if ``docs_path`` is a file, then the ``labels_path`` file
should contain a line with the corresponding list of category
labels for each document in ``docs_path``.
* if ``docs_path`` is a folder containing the documents, then
the ``labels_path`` file should contain a line for each document and
category label. Each line should have the following format:
``document_name<sep_label>label``.
:type labels_path: str
:param sep_label: the separator/delimiter used to separate either each label (if
``docs_path`` is a file) or the document name from its category
(if ``docs_path`` is a folder).
(default: ``';'`` when ``docs_path`` is a file, the ``'\s+'`` regular
expression otherwise).
:type sep_label: str
:param sep_doc: the separator/delimiter used to separate each document
when loading test documents from a single file. Valid
only when ``docs_path`` is a file. (default: ``'\n'``)
:type sep_doc: str
"""
Server.__clear_testset__()
Server.__test_path_prev__ = Server.__test_path__
Server.__test_path__ = docs_path
Server.__labels_path__ = labels_path
Server.__sep_doc__ = sep_doc
Server.__sep_label__ = sep_label

@staticmethod
def start_listening(port=0):
"""
Expand Down Expand Up @@ -564,7 +613,7 @@ def serve(
else:
Print.error("y_test must have the same length as x_test")
return
elif Server.__test_path__ and not Server.__docs__:
elif Server.__test_path__ and Server.__test_path_prev__ != Server.__test_path__:
Server.__load_testset_from_files__()

server_socket = Server.__server_socket__
13 changes: 11 additions & 2 deletions tests/test_server.py
@@ -19,11 +19,13 @@
PYTHON3 = sys.version_info[0] >= 3
DATASET_FOLDER = "dataset"
DATASET_FOLDER_MR = "dataset_mr"
DATASET_MULTILABEL_FOLDER = "dataset_ml"
ADDRESS, PORT = "localhost", None
LT = s.Live_Test

dataset_path = path.join(path.abspath(path.dirname(__file__)), DATASET_FOLDER)
dataset_path_mr = path.join(path.abspath(path.dirname(__file__)), DATASET_FOLDER_MR)
dataset_path_multilabel = path.join(path.abspath(path.dirname(__file__)), DATASET_MULTILABEL_FOLDER)

x_train, y_train = None, None
clf = None
@@ -47,6 +49,7 @@ class MockCmdLineArgs:
quiet = True
MODEL = "name"
path = dataset_path
path_labels = None
label = 'folder'
port = 0

@@ -62,7 +65,7 @@ def mockers(mocker):
"parse_args").return_value = MockCmdLineArgs


@pytest.fixture(params=[0, 1, 2, 3, 4, 5, 6, 7])
@pytest.fixture(params=[0, 1, 2, 3, 4, 5, 6, 7, 8])
def test_case(request, mocker):
"""Argument values generator for test_live_test(test_case)."""
mocker.patch("webbrowser.open")
@@ -73,6 +76,9 @@ def test_case(request, mocker):
LT.set_testset_from_files(dataset_path_mr, folder_label=True)
elif request.param == 2:
LT.set_testset(x_train, y_train)
elif request.param == 8:
LT.set_testset_from_files_multilabel(dataset_path_multilabel + "/train_files",
dataset_path_multilabel + "/file_labels.tsv")
else:
LT.__server_socket__ = None

@@ -132,7 +138,8 @@ def test_live_test(test_case):
serve_args = {
"x_test": x_train if test_case >= 2 else None,
"y_test": y_train if test_case == 2 else None,
"quiet": test_case != 0
"quiet": test_case != 0,
"browser": test_case == 0
}

if test_case == 4:
Expand All @@ -146,6 +153,8 @@ def test_live_test(test_case):
elif test_case == 7:
serve_args["y_test"] = y_train
serve_args["def_cat"] = 'xxxxx' # raise ValueError
elif test_case == 8:
serve_args["x_test"] = None

if PYTHON3:
threading.Thread(target=LT.serve, kwargs=serve_args, daemon=True).start()
