Skip to content

Commit

Permalink
Add the recordio gen for heart dataset (#1549)
Browse files Browse the repository at this point in the history
* Add the recordio gen for heart dataset

* Add recordio gen for heart dataset
  • Loading branch information
brightcoder01 committed Dec 5, 2019
1 parent 7dfe15d commit 856c844
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .isort.cfg
@@ -1,5 +1,5 @@
[settings]
multi_line_output=3
line_length=79
known_third_party = PIL,docker,google,grpc,kubernetes,numpy,odps,pyspark,recordio,requests,setuptools,tensorflow,yaml
known_third_party = PIL,docker,google,grpc,kubernetes,numpy,odps,pandas,pyspark,recordio,requests,setuptools,sklearn,tensorflow,yaml
include_trailing_comma=True
4 changes: 4 additions & 0 deletions elasticdl/docker/Dockerfile.dev
Expand Up @@ -25,4 +25,8 @@ RUN python /var/image_label.py --dataset mnist --fraction 0.15 \
COPY elasticdl/python/data/recordio_gen/frappe_recordio_gen.py /var/frappe_recordio_gen.py
RUN python /var/frappe_recordio_gen.py --data /root/.keras/datasets --output_dir /data/frappe \
--fraction 0.05
# Copy heart dataset
COPY elasticdl/python/data/recordio_gen/heart_recordio_gen.py /var/heart_recordio_gen.py
RUN python /var/heart_recordio_gen.py --data_dir /root/.keras/datasets --output_dir /data/heart

RUN rm -rf /root/.keras/datasets
134 changes: 134 additions & 0 deletions elasticdl/python/data/recordio_gen/heart_recordio_gen.py
@@ -0,0 +1,134 @@
import argparse
import os
import pathlib
import sys
import urllib
# Explicitly import the submodule: `import urllib` alone does not make
# `urllib.request` available; it only works by side effect of other imports.
import urllib.request

import pandas as pd
import recordio
import tensorflow as tf
from sklearn.model_selection import train_test_split

URL = "https://storage.googleapis.com/applied-dl/heart.csv"


def convert_series_to_tf_feature(data_series, columns, dtype_series):
    """Convert a pandas series (one row) to TensorFlow features.

    Args:
        data_series: Pandas series holding one row of data content.
        columns: Iterable of column names to convert.
        dtype_series: Pandas series mapping column name -> dtype.

    Returns:
        A dict of feature name -> tf.train.Feature.

    Raises:
        ValueError: If a column's dtype is not one of int64, float64,
            str or object.
    """
    features = {}
    for column_name in columns:
        value = data_series[column_name]
        dtype = dtype_series[column_name]

        if dtype == "int64":
            feature = tf.train.Feature(
                int64_list=tf.train.Int64List(value=[value])
            )
        elif dtype == "float64":
            feature = tf.train.Feature(
                float_list=tf.train.FloatList(value=[value])
            )
        elif dtype == "str":
            feature = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[value.encode("utf-8")])
            )
        elif dtype == "object":
            # Object columns (mixed/str data) are serialized via str().
            feature = tf.train.Feature(
                bytes_list=tf.train.BytesList(
                    value=[str(value).encode("utf-8")]
                )
            )
        else:
            # `assert` is stripped under `python -O`; raise explicitly so
            # unsupported dtypes always fail loudly. Also fixes the
            # "Unrecoginize" typo in the message.
            raise ValueError("Unrecognized dtype: {}".format(dtype))

        features[column_name] = feature

    return features


def convert_to_recordio_files(data_frame, dir_name, records_per_shard):
    """Convert a pandas DataFrame to RecordIO shard files.

    Each shard is named "data-%05d" and holds up to ``records_per_shard``
    serialized ``tf.train.Example`` records.

    Args:
        data_frame: The pandas DataFrame to convert.
        dir_name: The directory to put the generated RecordIO files in;
            created (with parents) if it does not exist.
        records_per_shard: The number of records per shard file.
    """
    pathlib.Path(dir_name).mkdir(parents=True, exist_ok=True)

    writer = None
    try:
        # enumerate() replaces the manual row counter; the DataFrame index
        # from iterrows() is unused.
        for row_num, (_, row) in enumerate(data_frame.iterrows()):
            if row_num % records_per_shard == 0:
                # Roll over to a new shard file.
                if writer:
                    writer.close()
                shard = row_num // records_per_shard
                writer = recordio.Writer(
                    os.path.join(dir_name, "data-%05d" % shard)
                )

            feature = convert_series_to_tf_feature(
                row, data_frame.columns, data_frame.dtypes
            )
            example = tf.train.Example(
                features=tf.train.Features(feature=feature)
            )
            writer.write(example.SerializeToString())
    finally:
        # Close the last (possibly partial) shard even when a write fails,
        # so the file handle is not leaked.
        if writer:
            writer.close()

    print("Finish data conversion in {}".format(dir_name))


def load_raw_data(data_dir):
    """Download the heart CSV dataset into a cache dir and load it.

    The file is downloaded from ``URL`` only when it is not already
    present in ``data_dir`` (which is created if missing).

    Args:
        data_dir: The cache directory for the downloaded CSV file.

    Returns:
        A pandas DataFrame with the heart dataset contents.
    """
    file_name = os.path.basename(URL)
    file_path = os.path.join(data_dir, file_name)
    pathlib.Path(data_dir).mkdir(parents=True, exist_ok=True)
    if not os.path.exists(file_path):
        # Requires `import urllib.request` at the top of the file —
        # `import urllib` alone does not load the `request` submodule.
        urllib.request.urlretrieve(URL, file_path)
    return pd.read_csv(file_path)


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--data_dir",
        help="The cache directory to put the data downloaded from the web",
    )
    arg_parser.add_argument(
        "--records_per_shard",
        type=int,
        default=128,
        help="Record number per shard",
    )
    arg_parser.add_argument(
        "--output_dir", help="The directory for the generated recordio files"
    )
    args = arg_parser.parse_args(sys.argv[1:])

    # Load the raw CSV, then carve it into train/val/test splits
    # (80/20, then 80/20 of the training portion again).
    heart_df = load_raw_data(args.data_dir)
    train_df, test_df = train_test_split(heart_df, test_size=0.2)
    train_df, val_df = train_test_split(train_df, test_size=0.2)

    # Write each split to its own subdirectory of RecordIO shards.
    for split_name, split_df in (
        ("train", train_df),
        ("val", val_df),
        ("test", test_df),
    ):
        convert_to_recordio_files(
            split_df,
            os.path.join(args.output_dir, split_name),
            args.records_per_shard,
        )
4 changes: 3 additions & 1 deletion elasticdl/requirements-dev.txt
Expand Up @@ -2,4 +2,6 @@ pytest
pytest-cov
mock
Pillow
pre-commit
pre-commit
pandas
scikit-learn

0 comments on commit 856c844

Please sign in to comment.