From 2dd6707443367985c6530379bbaa4a8cea2b2369 Mon Sep 17 00:00:00 2001 From: Qinlong Wang Date: Mon, 7 Dec 2020 11:46:32 +0800 Subject: [PATCH] Support creating shards for Text files (#2390) * Create shards for a csv file * Reader to partition csv files * Create csv reader * rename csv to text * Polish elasticdl job service * Polish elasticdl job service * Move the thread to check the timeout task into task manager * Delete unused imports * Fix conflicts * Pre-commit * Set flake8 * Fix by comments * delete the method to read records * Implement read_records * Fix shards to list --- .flake8 | 4 + elasticdl/python/data/reader/csv_reader.py | 75 ------------------- .../python/data/reader/data_reader_factory.py | 12 ++- elasticdl/python/data/reader/text_reader.py | 72 ++++++++++++++++++ elasticdl/python/tests/data_reader_test.py | 40 +++------- elasticdl/python/tests/test_utils.py | 5 +- 6 files changed, 100 insertions(+), 108 deletions(-) create mode 100644 .flake8 delete mode 100644 elasticdl/python/data/reader/csv_reader.py create mode 100644 elasticdl/python/data/reader/text_reader.py diff --git a/.flake8 b/.flake8 new file mode 100644 index 000000000..090829226 --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +ignore = E203, E266, W503 +max-line-length = 79 + diff --git a/elasticdl/python/data/reader/csv_reader.py b/elasticdl/python/data/reader/csv_reader.py deleted file mode 100644 index 2c8ebdbde..000000000 --- a/elasticdl/python/data/reader/csv_reader.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2020 The ElasticDL Authors. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import csv - -import numpy as np -import tensorflow as tf - -from elasticdl.python.data.reader.data_reader import ( - AbstractDataReader, - Metadata, - check_required_kwargs, -) - - -class CSVDataReader(AbstractDataReader): - """This reader is used to read data from a csv file. It is convenient for - user to locally run and debug a Keras model by using this reader. - However, it cannot be used with distribution strategy because it cannot - read data by line indices. - """ - - def __init__(self, **kwargs): - """ - Args: - kwargs should contains "sep" and "columns" like - 'sep=",",column=["sepal.length", "sepal.width", "variety"]' - """ - AbstractDataReader.__init__(self, **kwargs) - check_required_kwargs(["sep", "columns"], kwargs) - self.sep = kwargs.get("sep", ",") - self.selected_columns = kwargs.get("columns", None) - - def read_records(self, task): - with open(task.shard.name, "r", encoding="utf-8") as csv_file: - csv_reader = csv.reader(csv_file, delimiter=self.sep) - csv_columns = next(csv_reader) - selected_columns = ( - csv_columns - if self.selected_columns is None - else self.selected_columns - ) - if not set(selected_columns).issubset(set(csv_columns)): - raise ValueError( - "The first line in the csv file must be column names and " - "the selected columns are not in the file. The selected " - "columns are {} and the columns in {} are {}".format( - selected_columns, task.shard.name, csv_columns - ) - ) - column_indices = [csv_columns.index(e) for e in selected_columns] - for line in csv_reader: - line_elements = np.array(line, dtype=np.str) - yield line_elements[column_indices].tolist() - - def create_shards(self): - pass - - @property - def records_output_types(self): - return tf.string - - @property - def metadata(self): - return Metadata(column_names=self.selected_columns) diff --git a/elasticdl/python/data/reader/data_reader_factory.py b/elasticdl/python/data/reader/data_reader_factory.py index e6c67bffb..975ff0187 100644 --- a/elasticdl/python/data/reader/data_reader_factory.py +++ b/elasticdl/python/data/reader/data_reader_factory.py @@ -15,9 +15,9 @@ from elasticdl.python.common.constants import MaxComputeConfig, ReaderType from elasticdl.python.data.odps_io import is_odps_configured -from elasticdl.python.data.reader.csv_reader import CSVDataReader from elasticdl.python.data.reader.odps_reader import ODPSDataReader from elasticdl.python.data.reader.recordio_reader import RecordIODataReader +from elasticdl.python.data.reader.text_reader import TextDataReader def create_data_reader(data_origin, records_per_task=None, **kwargs): @@ -45,11 +45,17 @@ def create_data_reader(data_origin, records_per_task=None, **kwargs): **kwargs, ) elif data_origin and data_origin.endswith(".csv"): - return CSVDataReader(data_dir=data_origin, **kwargs) + return TextDataReader( + filename=data_origin, + records_per_task=records_per_task, + **kwargs, + ) else: return RecordIODataReader(data_dir=data_origin) elif reader_type == ReaderType.CSV_READER: - return CSVDataReader(data_dir=data_origin, **kwargs) + return TextDataReader( + filename=data_origin, records_per_task=records_per_task, **kwargs + ) elif reader_type == ReaderType.ODPS_READER: if not is_odps_configured: raise ValueError( diff --git a/elasticdl/python/data/reader/text_reader.py b/elasticdl/python/data/reader/text_reader.py new file mode 100644 index 000000000..106e91e70 --- /dev/null +++ b/elasticdl/python/data/reader/text_reader.py @@ -0,0 +1,72 @@ +# Copyright 2020 The ElasticDL Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import linecache + +import tensorflow as tf + +from elasticdl.python.data.reader.data_reader import ( + AbstractDataReader, + Metadata, +) + + +class TextDataReader(AbstractDataReader): + """This reader is used to create shards for a file and + read records from the shard. + """ + + def __init__(self, filename, records_per_task, **kwargs): + """ + Args: + kwargs should contains "filename" and "records_per_task". + """ + AbstractDataReader.__init__(self, **kwargs) + self._kwargs = kwargs + self._filename = filename + self._records_per_task = records_per_task + + def read_records(self, task): + records = linecache.getlines(task.shard.name)[ + task.shard.start : task.shard.end + ] + return records + + def create_shards(self): + size = self.get_size() + shards = [] + num_shards = size // self._records_per_task + start_ind = 0 + for shard_id in range(num_shards): + shards.append((self._filename, start_ind, self._records_per_task,)) + start_ind += self._records_per_task + # Create a shard with the last records + num_records_left = size % self._records_per_task + if num_records_left != 0: + shards.append((self._filename, start_ind, num_records_left,)) + return shards + + def get_size(self): + with open(self._filename) as file: + reader = csv.reader(file) + line_num = len(list(reader)) + return line_num + + @property + def records_output_types(self): + return tf.string + + @property + def metadata(self): + return Metadata(column_names=None) diff --git a/elasticdl/python/tests/data_reader_test.py b/elasticdl/python/tests/data_reader_test.py index c27c97d5c..8a8427f9d 100644 --- a/elasticdl/python/tests/data_reader_test.py +++ b/elasticdl/python/tests/data_reader_test.py @@ -26,11 +26,11 @@ from elasticdl.python.common.constants import MaxComputeConfig from elasticdl.python.common.model_utils import load_module from elasticdl.python.data.odps_io import is_odps_configured -from elasticdl.python.data.reader.csv_reader import CSVDataReader from elasticdl.python.data.reader.data_reader import Metadata from elasticdl.python.data.reader.data_reader_factory import create_data_reader from elasticdl.python.data.reader.odps_reader import ODPSDataReader from elasticdl.python.data.reader.recordio_reader import RecordIODataReader +from elasticdl.python.data.reader.text_reader import TextDataReader from elasticdl.python.master.task_manager import _Task from elasticdl.python.tests.test_utils import ( IRIS_TABLE_COLUMN_NAMES, @@ -73,7 +73,7 @@ def test_recordio_data_reader(self): self.assertEqual(len(v.numpy()), 1) -class CSVDataReaderTest(unittest.TestCase): +class TextDataReaderTest(unittest.TestCase): def test_csv_data_reader(self): with tempfile.TemporaryDirectory() as temp_dir_name: num_records = 128 @@ -87,33 +87,17 @@ def test_csv_data_reader(self): iris_file_name = create_iris_csv_file( size=num_records, columns=columns, temp_dir=temp_dir_name ) - csv_data_reader = CSVDataReader(columns=columns, sep=",") - task = _Task( - iris_file_name, 0, num_records, elasticdl_pb2.TRAINING + csv_data_reader = TextDataReader( + filename=iris_file_name, records_per_task=20 ) - - def _gen(): - for record in csv_data_reader.read_records(task): - yield record - - def _feed(dataset, mode, metadata): - def _parse_data(record): - features = tf.strings.to_number(record[0:-1], tf.float32) - label = tf.strings.to_number(record[-1], tf.float32) - return features, label - - dataset = dataset.map(_parse_data) - dataset = dataset.batch(10) - return dataset - - dataset = tf.data.Dataset.from_generator( - _gen, csv_data_reader.records_output_types - ) - dataset = _feed(dataset, None, None) - for features, labels in dataset: - self.assertEqual(features.shape.as_list(), [10, 4]) - self.assertEqual(labels.shape.as_list(), [10]) - break + shards = csv_data_reader.create_shards() + self.assertEqual(len(shards), 7) + task = _Task(iris_file_name, 0, 20, elasticdl_pb2.TRAINING) + record_count = 0 + for record in csv_data_reader.read_records(task): + record_count += 1 + self.assertEqual(csv_data_reader.get_size(), num_records) + self.assertEqual(record_count, 20) @unittest.skipIf( diff --git a/elasticdl/python/tests/test_utils.py b/elasticdl/python/tests/test_utils.py index c6c8f0ba8..99f75360f 100644 --- a/elasticdl/python/tests/test_utils.py +++ b/elasticdl/python/tests/test_utils.py @@ -270,7 +270,7 @@ def create_recordio_file(size, dataset_name, shape, temp_dir=None): return temp_file.name -def create_iris_csv_file(size, columns, temp_dir=None): +def create_iris_csv_file(size, columns, with_heads=False, temp_dir=None): """Creates a temporary CSV file. Args: @@ -291,7 +291,8 @@ def create_iris_csv_file(size, columns, temp_dir=None): csv_file_name = temp_file.name + ".csv" with open(csv_file_name, "w", newline="") as csv_file: csv_writer = csv.writer(csv_file) - csv_writer.writerow(columns) + if with_heads: + csv_writer.writerow(columns) csv_writer.writerows(value_data) return csv_file_name