From 856c844f1261fd26aa46641d0ca622ea5a457534 Mon Sep 17 00:00:00 2001 From: brightcoder01 <55301748+brightcoder01@users.noreply.github.com> Date: Thu, 5 Dec 2019 20:13:01 +0800 Subject: [PATCH] Add the recordio gen for heart dataset (#1549) * Add the recordio gen for heart dataset * Add recordio gen for heart dataset --- .isort.cfg | 2 +- elasticdl/docker/Dockerfile.dev | 4 + .../data/recordio_gen/heart_recordio_gen.py | 134 ++++++++++++++++++ elasticdl/requirements-dev.txt | 4 +- 4 files changed, 142 insertions(+), 2 deletions(-) create mode 100644 elasticdl/python/data/recordio_gen/heart_recordio_gen.py diff --git a/.isort.cfg b/.isort.cfg index e66677104..0b31d3fef 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,5 +1,5 @@ [settings] multi_line_output=3 line_length=79 -known_third_party = PIL,docker,google,grpc,kubernetes,numpy,odps,pyspark,recordio,requests,setuptools,tensorflow,yaml +known_third_party = PIL,docker,google,grpc,kubernetes,numpy,odps,pandas,pyspark,recordio,requests,setuptools,sklearn,tensorflow,yaml include_trailing_comma=True diff --git a/elasticdl/docker/Dockerfile.dev b/elasticdl/docker/Dockerfile.dev index 90c52930c..335b15eac 100644 --- a/elasticdl/docker/Dockerfile.dev +++ b/elasticdl/docker/Dockerfile.dev @@ -25,4 +25,8 @@ RUN python /var/image_label.py --dataset mnist --fraction 0.15 \ COPY elasticdl/python/data/recordio_gen/frappe_recordio_gen.py /var/frappe_recordio_gen.py RUN python /var/frappe_recordio_gen.py --data /root/.keras/datasets --output_dir /data/frappe \ --fraction 0.05 +# Copy heart dataset +COPY elasticdl/python/data/recordio_gen/heart_recordio_gen.py /var/heart_recordio_gen.py +RUN python /var/heart_recordio_gen.py --data_dir /root/.keras/datasets --output_dir /data/heart + RUN rm -rf /root/.keras/datasets diff --git a/elasticdl/python/data/recordio_gen/heart_recordio_gen.py b/elasticdl/python/data/recordio_gen/heart_recordio_gen.py new file mode 100644 index 000000000..bea72f736 --- /dev/null +++ 
# NOTE(review): this region of the patch adds the new file
# elasticdl/python/data/recordio_gen/heart_recordio_gen.py
# (hunk header: b/elasticdl/python/data/recordio_gen/heart_recordio_gen.py @@ -0,0 +1,134 @@).
# The patch was collapsed onto two physical lines; the module below restores
# conventional formatting and fixes two defects (see inline BUGFIX comments).
"""Download the heart-disease CSV and convert it to RecordIO shards.

Produces train/val/test RecordIO directories under ``--output_dir`` from the
CSV hosted at ``URL``, sharded every ``--records_per_shard`` rows.
"""
import argparse
import os
import pathlib
import sys
# BUGFIX: was `import urllib`; `urllib.request` is a submodule and must be
# imported explicitly, otherwise `urllib.request.urlretrieve` can raise
# AttributeError in a fresh interpreter.
import urllib.request

import pandas as pd
import recordio
import tensorflow as tf
from sklearn.model_selection import train_test_split

# CSV used by the TensorFlow structured-data tutorials.
URL = "https://storage.googleapis.com/applied-dl/heart.csv"


def convert_series_to_tf_feature(data_series, columns, dtype_series):
    """Convert one pandas row (Series) to TensorFlow features.

    Args:
        data_series: Pandas Series holding one row of data content.
        columns: Iterable of column names to convert.
        dtype_series: Pandas Series mapping column name -> dtype.

    Returns:
        A dict of feature name -> tf.train.Feature.

    Raises:
        ValueError: If a column's dtype is not int64/float64/str/object.
    """
    features = {}
    for column_name in columns:
        value = data_series[column_name]
        dtype = dtype_series[column_name]

        if dtype == "int64":
            feature = tf.train.Feature(
                int64_list=tf.train.Int64List(value=[value])
            )
        elif dtype == "float64":
            feature = tf.train.Feature(
                float_list=tf.train.FloatList(value=[value])
            )
        elif dtype == "str":
            feature = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[value.encode("utf-8")])
            )
        elif dtype == "object":
            feature = tf.train.Feature(
                bytes_list=tf.train.BytesList(
                    value=[str(value).encode("utf-8")]
                )
            )
        else:
            # BUGFIX: was `assert False, "Unrecoginize dtype: ..."`.
            # Asserts are stripped under `python -O`, silently producing
            # feature=None; raise an explicit error instead (also fixes the
            # "Unrecoginize" typo in the message).
            raise ValueError("Unrecognized dtype: {}".format(dtype))

        features[column_name] = feature

    return features


def convert_to_recordio_files(data_frame, dir_name, records_per_shard):
    """Convert a pandas DataFrame to RecordIO shard files.

    Args:
        data_frame: The DataFrame to convert.
        dir_name: Directory for the generated RecordIO files (created if
            missing).
        records_per_shard: Number of records written into each shard file.
    """
    pathlib.Path(dir_name).mkdir(parents=True, exist_ok=True)

    row_num = 0
    writer = None
    for _, row in data_frame.iterrows():
        # Roll over to a new shard file every `records_per_shard` rows.
        if row_num % records_per_shard == 0:
            if writer:
                writer.close()

            shard = row_num // records_per_shard
            file_path_name = os.path.join(dir_name, "data-%05d" % shard)
            writer = recordio.Writer(file_path_name)

        feature = convert_series_to_tf_feature(
            row, data_frame.columns, data_frame.dtypes
        )
        result_string = tf.train.Example(
            features=tf.train.Features(feature=feature)
        ).SerializeToString()
        writer.write(result_string)

        row_num += 1

    if writer:
        writer.close()

    print("Finish data conversion in {}".format(dir_name))


def load_raw_data(data_dir):
    """Download the heart CSV into `data_dir` (if absent) and load it.

    Args:
        data_dir: Cache directory for the downloaded CSV (created if
            missing).

    Returns:
        A pandas DataFrame with the CSV content.
    """
    file_name = os.path.basename(URL)
    file_path = os.path.join(data_dir, file_name)
    pathlib.Path(data_dir).mkdir(parents=True, exist_ok=True)
    if not os.path.exists(file_path):
        urllib.request.urlretrieve(URL, file_path)
    return pd.read_csv(file_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        help="The cache directory to put the data downloaded from the web",
    )
    parser.add_argument(
        "--records_per_shard",
        type=int,
        default=128,
        help="Record number per shard",
    )
    parser.add_argument(
        "--output_dir", help="The directory for the generated recordio files"
    )

    args = parser.parse_args(sys.argv[1:])

    data_frame = load_raw_data(args.data_dir)

    # 64/16/20 split: carve off 20% as test, then 20% of the remainder as
    # validation.
    train, test = train_test_split(data_frame, test_size=0.2)
    train, val = train_test_split(train, test_size=0.2)

    convert_to_recordio_files(
        train, os.path.join(args.output_dir, "train"), args.records_per_shard
    )
    convert_to_recordio_files(
        val, os.path.join(args.output_dir, "val"), args.records_per_shard
    )
    convert_to_recordio_files(
        test, os.path.join(args.output_dir, "test"), args.records_per_shard
    )
# NOTE(review): the collapsed patch continues past this file with the hunk:
# diff --git a/elasticdl/requirements-dev.txt b/elasticdl/requirements-dev.txt index 2e35de35d..ea4670f67
100644 --- a/elasticdl/requirements-dev.txt +++ b/elasticdl/requirements-dev.txt @@ -2,4 +2,6 @@ pytest pytest-cov mock Pillow -pre-commit \ No newline at end of file +pre-commit +pandas +scikit-learn \ No newline at end of file