Skip to content
Permalink
Browse files

Provide default implementation of dataset_fn for ODPS data source (#1531)

* Provide default implementation of dataset_fn for ODPS data source

Signed-off-by: terrytangyuan <terrytangyuan@gmail.com>

* Fix if-else logic

Signed-off-by: terrytangyuan <terrytangyuan@gmail.com>

* Fix test

Signed-off-by: terrytangyuan <terrytangyuan@gmail.com>

* Address comments

Signed-off-by: terrytangyuan <terrytangyuan@gmail.com>

* , -> ;

Signed-off-by: terrytangyuan <terrytangyuan@gmail.com>
  • Loading branch information
terrytangyuan committed Nov 27, 2019
1 parent ee304f1 commit 62b255a918df5b6594c888b19aebbcc74bbce6e4
@@ -1,6 +1,7 @@
import importlib.util
import os

from elasticdl.python.common.constants import ODPSConfig
from elasticdl.python.common.log_utils import default_logger as logger
from elasticdl.python.worker.prediction_outputs_processor import (
BasePredictionOutputsProcessor,
@@ -123,9 +124,22 @@ def get_model_spec(
"inherited from BasePredictionOutputsProcessor. "
"Prediction outputs may not be processed correctly."
)

# If ODPS data source is used, dataset_fn is optional
dataset_fn_required = not all(
k in os.environ
for k in (
ODPSConfig.PROJECT_NAME,
ODPSConfig.ACCESS_ID,
ODPSConfig.ACCESS_KEY,
)
)

return (
model,
_get_spec_value(dataset_fn, model_zoo, default_module, required=True),
_get_spec_value(
dataset_fn, model_zoo, default_module, required=dataset_fn_required
),
_get_spec_value(loss, model_zoo, default_module, required=True),
_get_spec_value(optimizer, model_zoo, default_module, required=True),
_get_spec_value(
@@ -5,7 +5,7 @@
import recordio
import tensorflow as tf

from elasticdl.python.common.constants import ODPSConfig
from elasticdl.python.common.constants import Mode, ODPSConfig
from elasticdl.python.data.odps_io import ODPSReader


@@ -165,6 +165,60 @@ def _get_reader(self, table_name):
def _get_odps_table_name(shard_name):
return shard_name.split(":")[0]

def default_dataset_fn(self):
    """Return a default `dataset_fn` for ODPS data sources.

    The returned function parses each CSV-style record retrieved from an
    ODPS table into `(features, labels)` pairs (or features only in
    prediction mode), using the table metadata available at call time.

    Requires the ``label_col`` kwarg naming the label column.

    Returns:
        A callable ``dataset_fn(dataset, mode, metadata)`` that maps the
        raw string dataset into parsed tensors and shuffles it when
        ``mode == Mode.TRAINING``.

    Raises:
        ValueError: (from the returned function) when the label column is
            missing from the retrieved table in a non-prediction mode.
    """
    _check_required_kwargs(["label_col"], self._kwargs)

    def dataset_fn(dataset, mode, metadata):
        def _parse_data(record):
            label_col_name = self._kwargs["label_col"]
            record = tf.strings.to_number(record, tf.float32)

            def _get_features_without_labels(
                record, label_col_idx, features_shape
            ):
                # Concatenate the slices before and after the label
                # column, then reshape into the feature tensor.
                features = [
                    record[:label_col_idx],
                    record[label_col_idx + 1 :],  # noqa: E203
                ]
                features = tf.concat(features, -1)
                return tf.reshape(features, features_shape)

            labels_shape = (1,)
            if mode == Mode.PREDICTION:
                if label_col_name in metadata.column_names:
                    # Label column present: strip it out of the record.
                    features_shape = (len(metadata.column_names) - 1, 1)
                    label_col_idx = metadata.column_names.index(
                        label_col_name
                    )
                    return _get_features_without_labels(
                        record, label_col_idx, features_shape
                    )
                else:
                    # Label column absent: every retrieved column is a
                    # feature, so the feature shape must use the full
                    # column count (not count - 1).
                    return tf.reshape(
                        record, (len(metadata.column_names), 1)
                    )
            else:
                features_shape = (len(metadata.column_names) - 1, 1)
                if label_col_name not in metadata.column_names:
                    raise ValueError(
                        "Missing the label column '%s' in the retrieved "
                        "ODPS table during %s mode."
                        % (label_col_name, mode)
                    )
                label_col_idx = metadata.column_names.index(label_col_name)
                labels = tf.reshape(record[label_col_idx], labels_shape)
                return (
                    _get_features_without_labels(
                        record, label_col_idx, features_shape
                    ),
                    labels,
                )

        dataset = dataset.map(_parse_data)

        # Shuffle only during training so evaluation/prediction order
        # stays deterministic.
        if mode == Mode.TRAINING:
            dataset = dataset.shuffle(buffer_size=200)
        return dataset

    return dataset_fn


def create_data_reader(data_origin, records_per_task=None, **kwargs):
if all(
@@ -109,9 +109,12 @@ def test_odps_data_reader_records_reading(self):

def test_create_data_reader(self):
reader = create_data_reader(
data_origin="table", records_per_task=10, **{"columns": ["a", "b"]}
data_origin="table",
records_per_task=10,
**{"columns": ["a", "b"], "label_col": "class"}
)
self.assertEqual(reader._kwargs["columns"], ["a", "b"])
self.assertEqual(reader._kwargs["label_col"], "class")
self.assertEqual(reader._kwargs["records_per_task"], 10)
reader = create_data_reader(data_origin="table", records_per_task=10)
self.assertEqual(reader._kwargs["records_per_task"], 10)
@@ -129,7 +132,12 @@ def test_odps_data_reader_integration_with_local_keras(self):
model = model_spec["custom_model"]()
optimizer = model_spec["optimizer"]()
loss = model_spec["loss"]
dataset_fn = model_spec["dataset_fn"]
reader = create_data_reader(
data_origin="table",
records_per_task=10,
**{"columns": IRIS_TABLE_COLUMN_NAMES, "label_col": "class"}
)
dataset_fn = reader.default_dataset_fn()

def _gen():
for data in self.reader.read_records(
@@ -132,6 +132,18 @@ def _init_from_args(self, args):
args.data_reader_params
),
)
if self._dataset_fn is None:
if hasattr(
self._task_data_service.data_reader, "default_dataset_fn"
):
self._dataset_fn = (
self._task_data_service.data_reader.default_dataset_fn()
)
else:
raise ValueError(
"dataset_fn is required if the data_reader used does "
"not provide default implementation of dataset_fn"
)
self._get_model_steps = args.get_model_steps
if self._get_model_steps > 1:
self._opt = self._opt_fn()
@@ -1,7 +1,5 @@
import tensorflow as tf

from elasticdl.python.common.constants import Mode


def custom_model():
inputs = tf.keras.layers.Input(shape=(4, 1), name="input")
@@ -22,53 +20,6 @@ def optimizer(lr=0.1):
return tf.optimizers.SGD(lr)


def dataset_fn(dataset, mode, metadata):
    """Parse raw iris records into feature/label tensors.

    Each string record is converted to float32 values. Outside prediction
    mode, a `(features, labels)` pair is produced and the dataset is
    shuffled during training; in prediction mode only features are
    returned, with the label column stripped when present.
    """

    def _parse_record(raw):
        target_column = "class"
        values = tf.strings.to_number(raw, tf.float32)

        def _features_only(values, target_idx, shape):
            # Drop the label column and rebuild the feature tensor.
            joined = tf.concat(
                [
                    values[:target_idx],
                    values[target_idx + 1 :],  # noqa: E203
                ],
                -1,
            )
            return tf.reshape(joined, shape)

        feature_shape = (4, 1)
        target_shape = (1,)

        if mode == Mode.PREDICTION:
            if target_column not in metadata.column_names:
                # No label column retrieved; the record is all features.
                return tf.reshape(values, feature_shape)
            target_idx = metadata.column_names.index(target_column)
            return _features_only(values, target_idx, feature_shape)

        if target_column not in metadata.column_names:
            raise ValueError(
                "Missing the label column '%s' in the retrieved "
                "ODPS table." % target_column
            )
        target_idx = metadata.column_names.index(target_column)
        targets = tf.reshape(values[target_idx], target_shape)
        return (
            _features_only(values, target_idx, feature_shape),
            targets,
        )

    dataset = dataset.map(_parse_record)

    # Only training benefits from shuffling; keep other modes ordered.
    if mode == Mode.TRAINING:
        dataset = dataset.shuffle(buffer_size=200)
    return dataset


def eval_metrics_fn():
return {
"accuracy": lambda labels, predictions: tf.equal(
@@ -81,7 +81,7 @@ elif [[ "$JOB_TYPE" == "odps" ]]; then
--model_zoo=model_zoo \
--model_def=odps_iris_dnn_model.odps_iris_dnn_model.custom_model \
--training_data=$ODPS_TABLE_NAME \
--data_reader_params='columns=["sepal_length", "sepal_width", "petal_length", "petal_width", "class"]' \
--data_reader_params='columns=["sepal_length", "sepal_width", "petal_length", "petal_width", "class"]; label_col="class"' \
--envs="ODPS_PROJECT_NAME=$ODPS_PROJECT_NAME,ODPS_ACCESS_ID=$ODPS_ACCESS_ID,ODPS_ACCESS_KEY=$ODPS_ACCESS_KEY,ODPS_ENDPOINT=" \
--num_epochs=2 \
--master_resource_request="cpu=0.2,memory=1024Mi" \

0 comments on commit 62b255a

Please sign in to comment.
You can’t perform that action at this time.