Skip to content

Commit

Permalink
support partition format (#1548)
Browse files Browse the repository at this point in the history
* support partition format

* fix black

* add test and rewrite wkargs get

* change default value to None
  • Loading branch information
skydoorkai authored and QiJune committed Dec 4, 2019
1 parent a82c118 commit f105b0d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
10 changes: 8 additions & 2 deletions elasticdl/python/common/model_utils.py
Expand Up @@ -38,8 +38,14 @@ def get_dict_from_params_str(params_str):
kvs = params_str.split(";")
params_dict = {}
for kv in kvs:
k, v = kv.strip().split("=")
params_dict[k] = eval(v)
splitted = kv.strip().split("=")
k = splitted[0]
# if there is '=' in value, need to restore it.
v = "=".join(splitted[1:])
try:
params_dict[k] = eval(v)
except Exception:
params_dict[k] = v
return params_dict
else:
return None
Expand Down
9 changes: 8 additions & 1 deletion elasticdl/python/master/master.py
Expand Up @@ -19,6 +19,7 @@
from elasticdl.python.common.log_utils import get_logger
from elasticdl.python.common.model_handler import ModelHandler
from elasticdl.python.common.model_utils import (
get_dict_from_params_str,
get_module_file_path,
load_model_from_module,
load_module,
Expand All @@ -37,13 +38,18 @@ def _make_task_dispatcher(
prediction_data,
records_per_task,
num_epochs,
data_reader_params,
):
# TODO: Support any subclasses of `AbstractDataReader`
# and support passing specified parameters to the constructor
def _maybe_create_shards(data_origin):
wkargs = get_dict_from_params_str(data_reader_params)
partition = wkargs.get("partition", None) if wkargs else None
return (
create_data_reader(
data_origin=data_origin, records_per_task=records_per_task
data_origin=data_origin,
records_per_task=records_per_task,
partition=partition,
).create_shards()
if data_origin
else {}
Expand Down Expand Up @@ -92,6 +98,7 @@ def __init__(self, args):
args.prediction_data,
records_per_task,
args.num_epochs,
args.data_reader_params,
)

saved_model_path = args.output
Expand Down
4 changes: 4 additions & 0 deletions elasticdl/python/tests/model_utils_test.py
Expand Up @@ -92,6 +92,10 @@ def test_get_dict_from_params_str(self):
get_dict_from_params_str('ls=["a", "b"]; d={"a": 3}'),
{"ls": ["a", "b"], "d": {"a": 3}},
)
self.assertEqual(
get_dict_from_params_str('ls=["a", "b"];partition=dt=20190011'),
{"ls": ["a", "b"], "partition": "dt=20190011"},
)
self.assertEqual(get_dict_from_params_str(""), None)


Expand Down

0 comments on commit f105b0d

Please sign in to comment.