Exposed PyArrow filters in the make_reader and make_batch_reader API (#564)

* fixed conflict

* reversed travis changes

* fixed typo

* fixed lint

* fixed tests

* fixed lint

* Added comment

* addressed comments

* addressed comments

* addressed comments
abditag2 committed Jul 24, 2020
1 parent 0a7f32a commit ae95772
Showing 6 changed files with 56 additions and 11 deletions.
2 changes: 2 additions & 0 deletions README.rst
@@ -361,6 +361,8 @@ The reader returns one record at a time. The reader r
size.
------------------------------------------------------------------ -----------------------------------------------------
Predicates passed to ``make_reader`` are evaluated per single row. Predicates passed to ``make_batch_reader`` are evaluated per batch.
------------------------------------------------------------------ -----------------------------------------------------
Can filter the parquet file based on the ``filters`` argument. Can filter the parquet file based on the ``filters`` argument.
================================================================== =====================================================
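
For context, the ``filters`` value accepted by both readers follows PyArrow's predicate format (disjunctive normal form). A minimal sketch, with hypothetical column names:

# A flat list of tuples is AND-ed together:
filters = [('year', '=', 2020), ('month', '>=', 6)]
# A list of lists expresses OR across AND-groups:
filters = [[('year', '=', 2019)], [('year', '=', 2020), ('month', '>=', 6)]]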


2 changes: 2 additions & 0 deletions docs/release-notes.rst
@@ -15,6 +15,8 @@ Thanks to our new contributors: Travis Addair and Ryan (rb-determined-ai).

- Retire support for Python 2.
- `PR 568 <https://github.com/uber/petastorm/pull/568>`_: Added additional kwargs for Spark Dataset Converter.
- `PR 564 <https://github.com/uber/petastorm/pull/564>`_: Expose the ``filters`` (PyArrow filters) argument in ``make_reader`` and ``make_batch_reader``.



Release 0.9.2
3 changes: 2 additions & 1 deletion petastorm/arrow_reader_worker.py
@@ -99,6 +99,7 @@ def __init__(self, worker_id, publish_func, args):
self._local_cache = args[5]
self._transform_spec = args[6]
self._transformed_schema = args[7]
self._arrow_filters = args[8]

if self._ngram:
raise NotImplementedError('ngrams are not supported by ArrowReaderWorker')
@@ -129,7 +130,7 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition):
self._dataset = pq.ParquetDataset(
self._dataset_path_or_paths,
filesystem=self._filesystem,
validate_schema=False)
validate_schema=False, filters=self._arrow_filters)

if self._dataset.partitions is None:
# When read from parquet file list, the `dataset.partitions` will be None.
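
For context, the worker simply forwards ``self._arrow_filters`` to ``pyarrow.parquet.ParquetDataset``. A minimal sketch of how PyArrow interprets that value (the path and column name below are hypothetical); with this legacy ``ParquetDataset`` API the filters typically prune whole partitions/row groups rather than individual rows:

import pyarrow.parquet as pq

filters = [('partition_key', '=', 'p_5')]  # (column, op, value) tuples, AND-ed together
# Dataset pieces whose partition values cannot satisfy the predicate are skipped.
dataset = pq.ParquetDataset('/tmp/hypothetical_dataset', filters=filters)
table = dataset.read()
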
3 changes: 2 additions & 1 deletion petastorm/py_dict_reader_worker.py
@@ -108,6 +108,7 @@ def __init__(self, worker_id, publish_func, args):
self._split_pieces = args[4]
self._local_cache = args[5]
self._transform_spec = args[6]
self._arrow_filters = args[8]

# We create datasets lazily in the first invocation of 'def process'. This speeds up startup time since
# all Worker constructors are serialized
@@ -135,7 +136,7 @@ def process(self, piece_index, worker_predicate, shuffle_row_drop_partition):
self._dataset = pq.ParquetDataset(
self._dataset_path,
filesystem=self._filesystem,
validate_schema=False)
validate_schema=False, filters=self._arrow_filters)

piece = self._split_pieces[piece_index]

30 changes: 22 additions & 8 deletions petastorm/reader.py
@@ -67,7 +67,8 @@ def make_reader(dataset_url,
cache_type='null', cache_location=None, cache_size_limit=None,
cache_row_size_estimate=None, cache_extra_settings=None,
hdfs_driver='libhdfs3',
transform_spec=None):
transform_spec=None,
filters=None):
"""
Creates an instance of Reader for reading Petastorm datasets. A Petastorm dataset is a dataset generated using
:func:`~petastorm.etl.dataset_metadata.materialize_dataset` context manager as explained
@@ -117,6 +118,9 @@
:param transform_spec: An instance of :class:`~petastorm.transform.TransformSpec` object defining how a record
is transformed after it is loaded and decoded. The transformation occurs on a worker thread/process (depends
on the ``reader_pool_type`` value).
:param filters: (List[Tuple] or List[List[Tuple]]): Standard PyArrow filters.
These will be applied when loading the parquet file with PyArrow. More information
here: https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetDataset.html
:return: A :class:`Reader` object
"""
dataset_url = normalize_dir_url(dataset_url)
@@ -161,6 +165,7 @@
'shard_count': shard_count,
'cache': cache,
'transform_spec': transform_spec,
'filters': filters
}

try:
@@ -187,7 +192,8 @@ def make_batch_reader(dataset_url_or_urls,
cache_type='null', cache_location=None, cache_size_limit=None,
cache_row_size_estimate=None, cache_extra_settings=None,
hdfs_driver='libhdfs3',
transform_spec=None):
transform_spec=None,
filters=None):
"""
Creates an instance of Reader for reading batches out of a non-Petastorm Parquet store.
@@ -241,6 +247,9 @@
:param transform_spec: An instance of :class:`~petastorm.transform.TransformSpec` object defining how a record
is transformed after it is loaded and decoded. The transformation occurs on a worker thread/process (depends
on the ``reader_pool_type`` value).
:param filters: (List[Tuple] or List[List[Tuple]]): Standard PyArrow filters.
These will be applied when loading the parquet file with PyArrow. More information
here: https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetDataset.html
:return: A :class:`Reader` object
"""
dataset_url_or_urls = normalize_dataset_url_or_urls(dataset_url_or_urls)
@@ -287,7 +296,8 @@
shard_count=shard_count,
cache=cache,
transform_spec=transform_spec,
is_batched_reader=True)
is_batched_reader=True,
filters=filters)


class Reader(object):
@@ -300,7 +310,7 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None,
shuffle_row_groups=True, shuffle_row_drop_partitions=1,
predicate=None, rowgroup_selector=None, reader_pool=None, num_epochs=1,
cur_shard=None, shard_count=None, cache=None, worker_class=None,
transform_spec=None, is_batched_reader=False):
transform_spec=None, is_batched_reader=False, filters=None):
"""Initializes a reader object.
:param pyarrow_filesystem: An instance of ``pyarrow.FileSystem`` that will be used. If not specified,
@@ -336,9 +346,11 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None,
to the main data store is either slow or expensive and the local machine has large enough storage
to store entire dataset (or a partition of a dataset if shards are used).
By default, use the :class:`.NullCache` implementation.
:param worker_class: This is the class that will be instantiated on a different thread/process. Its
responsibility is to load and filter the data.
:param filters: (List[Tuple] or List[List[Tuple]]): Standard PyArrow filters.
These will be applied when loading the parquet file with PyArrow. More information
here: https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetDataset.html
"""

# 1. Open the parquet storage (dataset)
Expand All @@ -357,7 +369,8 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None,
self.is_batched_reader = is_batched_reader
# 1. Resolve dataset path (hdfs://, file://) and open the parquet storage (dataset)
self.dataset = pq.ParquetDataset(dataset_path, filesystem=pyarrow_filesystem,
validate_schema=False, metadata_nthreads=10)
validate_schema=False, metadata_nthreads=10,
filters=filters)

if self.dataset.partitions is None:
# When read from parquet file list, the `dataset.partitions` will be None.
@@ -412,8 +425,9 @@ def __init__(self, pyarrow_filesystem, dataset_path, schema_fields=None,
self._workers_pool.workers_count + _VENTILATE_EXTRA_ROWGROUPS)

# 5. Start workers pool
self._workers_pool.start(worker_class, (pyarrow_filesystem, dataset_path, storage_schema, self.ngram,
row_groups, cache, transform_spec, self.schema),
self._workers_pool.start(worker_class, (pyarrow_filesystem, dataset_path, storage_schema,
self.ngram, row_groups, cache, transform_spec,
self.schema, filters),
ventilator=self.ventilator)
logger.debug('Workers pool started')

27 changes: 26 additions & 1 deletion petastorm/tests/test_end_to_end.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import operator
import os
from concurrent.futures import ThreadPoolExecutor
@@ -31,7 +33,7 @@
from petastorm.etl.dataset_metadata import materialize_dataset
from petastorm.predicates import in_lambda
from petastorm.selectors import SingleIndexSelector, IntersectIndexSelector, UnionIndexSelector
from petastorm.tests.test_common import create_test_dataset, TestSchema
from petastorm.tests.test_common import create_test_dataset, TestSchema, create_test_scalar_dataset
from petastorm.tests.test_end_to_end_predicates_impl import \
PartitionKeyInSetPredicate, EqualPredicate, VectorizedEqualPredicate
from petastorm.unischema import UnischemaField, Unischema
@@ -857,3 +859,26 @@ def test_make_batch_reader_with_url_list(scalar_dataset):
row_count += len(batch.id)

assert row_count == 100


def test_pyarrow_filters_make_reader(synthetic_dataset):
with make_reader(synthetic_dataset.url, workers_count=5, num_epochs=1,
filters=[('partition_key', '=', 'p_5'), ]) as reader:
uv = set()
for data in reader:
uv.add(data.partition_key)

assert uv == {'p_5'}


def test_pyarrow_filters_make_batch_reader():
path = tempfile.mkdtemp()
url = 'file://' + path
create_test_scalar_dataset(url, 3000, partition_by=['id_div_700'])
with make_batch_reader(url, filters=[('id_div_700', '=', 2), ]) as reader:
uv = set()
for data in reader:
for _id_div_700 in data.id_div_700:
uv.add(_id_div_700)

assert uv == {2}
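
The new tests cover the flat ``List[Tuple]`` form; a hypothetical variant (not part of this commit) exercising the ``List[List[Tuple]]`` OR form could look like this, reusing the same partitioned test dataset:

def test_pyarrow_filters_or_groups():
    # Hypothetical test, not included in this commit.
    path = tempfile.mkdtemp()
    url = 'file://' + path
    create_test_scalar_dataset(url, 3000, partition_by=['id_div_700'])
    # Keep rows whose id_div_700 is either 1 or 2 (OR of two AND-groups).
    with make_batch_reader(url, filters=[[('id_div_700', '=', 1)], [('id_div_700', '=', 2)]]) as reader:
        uv = {v for batch in reader for v in batch.id_div_700}
    assert uv == {1, 2}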
