Support pyarrow 0.15 API
Yevgeni Litvin authored and selitvin committed Oct 19, 2019
1 parent 9b58038 commit 0ddae8f
Showing 12 changed files with 153 additions and 60 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
@@ -18,8 +18,8 @@ python:
- '2.7'
- '3.6'
env:
- PYARROW_VERSION=0.10.0 TF_VERSION=1.12 DEPLOYABLE=0
- PYARROW_VERSION=0.12.1 TF_VERSION=1.12 DEPLOYABLE=1
- PYARROW_VERSION=0.11.1 TF_VERSION=1.12 DEPLOYABLE=0
- PYARROW_VERSION=0.15.0 TF_VERSION=1.12 DEPLOYABLE=1
services:
- docker

28 changes: 16 additions & 12 deletions petastorm/arrow_reader_worker.py
@@ -23,6 +23,7 @@
from pyarrow.parquet import ParquetFile

from petastorm.cache import NullCache
from petastorm.compat import compat_piece_read, compat_table_columns_gen, compat_column_num_chunks
from petastorm.workers_pool import EmptyResultError
from petastorm.workers_pool.worker_base import WorkerBase

@@ -44,29 +45,36 @@ def read_next(self, workers_pool, schema, ngram):
# Convert arrow table columns into numpy. Strings are handled differently since to_pandas() returns
# a numpy array of dtype=object.
result_dict = dict()
for column in result_table.columns:
for column_name, column in compat_table_columns_gen(result_table):
# Assume we get only one chunk since the reader worker reads one rowgroup at a time

# `to_pandas` is slower when called on the entire column data rather than directly on a chunk.
if result_table.column(0).data.num_chunks == 1:
if compat_column_num_chunks(result_table.column(0)) == 1:
column_as_pandas = column.data.chunks[0].to_pandas()
else:
column_as_pandas = column.data.to_pandas()

# pyarrow < 0.15.0 always returned a numpy array. Starting with 0.15 we get a pandas Series, hence we
# convert it into a numpy array
if isinstance(column_as_pandas, pd.Series):
column_as_numpy = column_as_pandas.values
else:
column_as_numpy = column_as_pandas

if pa.types.is_string(column.type):
result_dict[column.name] = column_as_pandas.astype(np.unicode_)
result_dict[column_name] = column_as_numpy.astype(np.unicode_)
elif pa.types.is_list(column.type):
# Assuming all lists are of the same length, we can collate them into a matrix
list_of_lists = column_as_pandas
list_of_lists = column_as_numpy
try:
result_dict[column.name] = np.vstack(list_of_lists.tolist())
result_dict[column_name] = np.vstack(list_of_lists.tolist())
except ValueError:
raise RuntimeError('All values in column \'{}\' are expected to have the same length. '
'Got the following set of lengths: \'{}\''
.format(column.name,
.format(column_name,
', '.join(str(value.shape[0]) for value in list_of_lists)))
else:
result_dict[column.name] = column_as_pandas
result_dict[column_name] = column_as_numpy

return schema.make_namedtuple(**result_dict)
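
For illustration only (not part of the commit), a minimal sketch of the conversion described in the comments above, assuming pyarrow >= 0.15 where Table.column() returns a ChunkedArray:

import numpy as np
import pandas as pd
import pyarrow as pa

table = pa.Table.from_arrays([pa.array([1, 2, 3])], names=['x'])
chunked = table.column(0)  # a ChunkedArray on pyarrow >= 0.15
as_pandas = chunked.chunks[0].to_pandas() if chunked.num_chunks == 1 else chunked.to_pandas()
# Starting with 0.15, to_pandas() yields a pandas Series; older releases returned a numpy array
as_numpy = as_pandas.values if isinstance(as_pandas, pd.Series) else as_pandas
assert isinstance(as_numpy, np.ndarray)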

@@ -236,11 +244,7 @@ def _load_rows_with_predicate(self, pq_file, piece, worker_predicate, shuffle_ro
return pa.Table.from_pandas(result, preserve_index=False)

def _read_with_shuffle_row_drop(self, piece, pq_file, column_names, shuffle_row_drop_partition):
table = piece.read(
open_file_func=lambda _: pq_file,
columns=column_names,
partitions=self._dataset.partitions
)
table = compat_piece_read(piece, lambda _: pq_file, columns=column_names, partitions=self._dataset.partitions)

num_rows = len(table)
num_partitions = shuffle_row_drop_partition[1]
64 changes: 64 additions & 0 deletions petastorm/compat.py
@@ -0,0 +1,64 @@
# Copyright (c) 2017-2018 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""0.15.0 cancelled previously deprecated functions. We still want to support 0.11 as it is being used by some users.
This file implements compatibility interfaces. Once we drop support of 0.11, we can get rid of this file."""

import pyarrow as pa
from packaging import version
from pyarrow import parquet as pq

_PYARROW_BEFORE_013 = version.parse(pa.__version__) < version.parse('0.13.0')


def compat_get_metadata(piece, open_func):
if _PYARROW_BEFORE_013:
arrow_metadata = piece.get_metadata(open_func)
else:
arrow_metadata = piece.get_metadata()
return arrow_metadata


def compat_piece_read(piece, open_file_func, **kwargs):
if _PYARROW_BEFORE_013:
table = piece.read(open_file_func=open_file_func, **kwargs)
else:
table = piece.read(**kwargs)
return table


def compat_table_columns_gen(table):
if _PYARROW_BEFORE_013:
for column in table.columns:
name = column.name
yield name, column
else:
for name in table.column_names:
column = table.column(name)
yield name, column


def compat_column_num_chunks(column):
if _PYARROW_BEFORE_013:
return column.data.num_chunks
else:
return column.num_chunks


def compat_make_parquet_piece(path, open_file_func, **kwargs):
if _PYARROW_BEFORE_013:
return pq.ParquetDatasetPiece(path, **kwargs)
else:
return pq.ParquetDatasetPiece(path, open_file_func=open_file_func, # pylint: disable=unexpected-keyword-arg
**kwargs)
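
A hypothetical usage sketch (not part of the commit) of the helpers above; the dataset path below is made up for illustration:

from pyarrow import parquet as pq

from petastorm.compat import compat_make_parquet_piece, compat_piece_read

dataset = pq.ParquetDataset('/tmp/some_parquet_store')  # hypothetical local dataset path
piece = compat_make_parquet_piece(dataset.pieces[0].path, dataset.fs.open, row_group=0)
table = compat_piece_read(piece, dataset.fs.open, columns=None, partitions=dataset.partitions)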
16 changes: 9 additions & 7 deletions petastorm/etl/dataset_metadata.py
@@ -24,6 +24,7 @@
from six.moves.urllib.parse import urlparse

from petastorm import utils
from petastorm.compat import compat_get_metadata, compat_make_parquet_piece
from petastorm.etl.legacy import depickle_legacy_package_name_compatible
from petastorm.fs_utils import FilesystemResolver
from petastorm.unischema import Unischema
@@ -267,8 +268,8 @@ def load_row_groups(dataset):
# looking up the number of row groups.
row_groups_key = os.path.relpath(piece.path, dataset.paths)
for row_group in range(row_groups_per_file[row_groups_key]):
rowgroups.append(pq.ParquetDatasetPiece(piece.path, row_group=row_group,
partition_keys=piece.partition_keys))
rowgroups.append(compat_make_parquet_piece(piece.path, dataset.fs.open, row_group=row_group,
partition_keys=piece.partition_keys))
return rowgroups


@@ -304,7 +305,8 @@ def _split_row_groups(dataset):
continue

for row_group in range(row_groups_per_file[relative_path]):
split_piece = pq.ParquetDatasetPiece(piece.path, row_group=row_group, partition_keys=piece.partition_keys)
split_piece = compat_make_parquet_piece(piece.path, dataset.fs.open, row_group=row_group,
partition_keys=piece.partition_keys)
split_pieces.append(split_piece)

return split_pieces
@@ -321,10 +323,10 @@ def _split_row_groups_from_footers(dataset):
thread_pool = futures.ThreadPoolExecutor()

def split_piece(piece):
metadata = piece.get_metadata(dataset.fs.open)
return [pq.ParquetDatasetPiece(piece.path,
row_group=row_group,
partition_keys=piece.partition_keys)
metadata = compat_get_metadata(piece, dataset.fs.open)
return [compat_make_parquet_piece(piece.path, dataset.fs.open,
row_group=row_group,
partition_keys=piece.partition_keys)
for row_group in range(metadata.num_row_groups)]

futures_list = [thread_pool.submit(split_piece, piece) for piece in dataset.pieces]
17 changes: 9 additions & 8 deletions petastorm/etl/rowgroup_indexing.py
@@ -21,6 +21,7 @@
from six.moves import range

from petastorm import utils
from petastorm.compat import compat_piece_read, compat_make_parquet_piece
from petastorm.etl import dataset_metadata
from petastorm.etl.legacy import depickle_legacy_package_name_compatible
from petastorm.fs_utils import FilesystemResolver
@@ -92,22 +93,22 @@ def _index_columns(piece_info, dataset_url, partitions, indexers, schema, hdfs_d
libhdfs (java through JNI) or libhdfs3 (C++)
:return: list of indexers containing index data
"""
# Resolver in executor context will get hadoop config from environment
resolver = FilesystemResolver(dataset_url, hdfs_driver=hdfs_driver)
fs = resolver.filesystem()

# Create pyarrow piece
piece = pq.ParquetDatasetPiece(piece_info.path, row_group=piece_info.row_group,
partition_keys=piece_info.partition_keys)
piece = compat_make_parquet_piece(piece_info.path, fs.open, row_group=piece_info.row_group,
partition_keys=piece_info.partition_keys)

# Collect column names needed for indexing
column_names = set()
for indexer in indexers:
column_names.update(indexer.column_names)

# Read columns needed for indexing
# Resolver in executor context will get hadoop config from environment
resolver = FilesystemResolver(dataset_url, hdfs_driver=hdfs_driver)
column_rows = piece.read(
open_file_func=resolver.filesystem().open,
columns=list(column_names),
partitions=partitions).to_pandas().to_dict('records')
column_rows = compat_piece_read(piece, fs.open, columns=list(column_names),
partitions=partitions).to_pandas().to_dict('records')

# Decode columns values
decoded_rows = [utils.decode_row(row, schema) for row in column_rows]
4 changes: 2 additions & 2 deletions petastorm/hdfs/namenode.py
@@ -304,10 +304,10 @@ def _try_next_namenode(cls, index_of_nn, list_of_namenodes, user=None):
try:
return idx, \
cls.hdfs_connect_namenode(urlparse('hdfs://' + str(host or 'default')), user=user)
except ArrowIOError:
except ArrowIOError as e:
# This is an expected error if the namenode we are trying to connect to is
# not the active one
logger.debug('Attempted to connect to namenode %s but failed', host)
logger.debug('Attempted to connect to namenode %s but failed: %s', host, str(e))
# It is a problem if we cannot connect to either of the namenodes when tried back-to-back,
# so better raise an error.
raise HdfsConnectError("Unable to connect to HDFS cluster!")
8 changes: 3 additions & 5 deletions petastorm/py_dict_reader_worker.py
@@ -22,6 +22,7 @@

from petastorm import utils
from petastorm.cache import NullCache
from petastorm.compat import compat_piece_read
from petastorm.workers_pool import EmptyResultError
from petastorm.workers_pool.worker_base import WorkerBase

@@ -253,11 +254,8 @@ def _read_with_shuffle_row_drop(self, piece, pq_file, column_names, shuffle_row_
def _read_with_shuffle_row_drop(self, piece, pq_file, column_names, shuffle_row_drop_partition):
# If integer_object_nulls is set to False, nullable integer fields are returned as floats
# with nulls translated to nans
data_frame = piece.read(
open_file_func=lambda _: pq_file,
columns=column_names,
partitions=self._dataset.partitions
).to_pandas(integer_object_nulls=True)
data_frame = compat_piece_read(piece, lambda _: pq_file, columns=column_names,
partitions=self._dataset.partitions).to_pandas(integer_object_nulls=True)

num_rows = len(data_frame)
num_partitions = shuffle_row_drop_partition[1]
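
As an aside (not part of the commit), a minimal sketch of the integer_object_nulls behavior mentioned in the comment at the top of this hunk:

import pyarrow as pa

table = pa.Table.from_arrays([pa.array([1, None, 3], type=pa.int32())], names=['x'])
as_floats = table.to_pandas()                            # nullable ints come back as float64 with NaN
as_objects = table.to_pandas(integer_object_nulls=True)  # ints kept as Python objects, nulls as None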
53 changes: 34 additions & 19 deletions petastorm/tests/test_common.py
@@ -19,11 +19,12 @@
from functools import partial

import numpy as np
import pyarrow as pa
import pytz
from pyspark import Row
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, DecimalType, DoubleType, StructField, \
IntegerType, StructType, DateType, TimestampType, ArrayType, ShortType
IntegerType, StructType, DateType, TimestampType, ShortType, ArrayType

from petastorm.codecs import CompressedImageCodec, NdarrayCodec, \
ScalarCodec
@@ -170,6 +171,9 @@ def create_test_scalar_dataset(output_url, num_rows, num_files=4, spark=None, pa
:param partition_by: A list of fields to partition the parquet store by.
:return: A list of records with a copy of the data written to the dataset.
"""

is_list_of_scalar_broken = pa.__version__ == '0.15.0'

partition_by = partition_by or []
shutdown = False
if not spark:
@@ -181,36 +185,47 @@
spark = spark_session.getOrCreate()
shutdown = True

expected_data = [{'id': np.int32(i),
'int_fixed_size_list': np.arange(1 + i, 10 + i).astype(np.int32),
'datetime': np.datetime64('2019-01-02'),
'timestamp': np.datetime64('2005-02-25T03:30'),
'string': np.unicode_('hello_{}'.format(i)),
'string2': np.unicode_('world_{}'.format(i)),
'float64': np.float64(i) * .66} for i in range(num_rows)]
def expected_row(i):
result = {'id': np.int32(i),
'datetime': np.datetime64('2019-01-02'),
'timestamp': np.datetime64('2005-02-25T03:30'),
'string': np.unicode_('hello_{}'.format(i)),
'string2': np.unicode_('world_{}'.format(i)),
'float64': np.float64(i) * .66}
if not is_list_of_scalar_broken:
result['int_fixed_size_list'] = np.arange(1 + i, 10 + i).astype(np.int32)
return result

expected_data_as_scalars = [{k: np.asscalar(v) if isinstance(v, np.generic) else v for k, v in row.items()} for row
expected_data = [expected_row(i) for i in range(num_rows)]

expected_data_as_scalars = [{k: v.item() if isinstance(v, np.generic) else v for k, v in row.items()} for row
in expected_data]

# np.datetime64 is converted to timezone-unaware datetime instances. Working explicitly in UTC so we don't need
# to think about the local timezone in the tests
for row in expected_data_as_scalars:
row['timestamp'] = row['timestamp'].replace(tzinfo=pytz.UTC)
row['int_fixed_size_list'] = row['int_fixed_size_list'].tolist()
if not is_list_of_scalar_broken:
row['int_fixed_size_list'] = row['int_fixed_size_list'].tolist()

rows = [Row(**row) for row in expected_data_as_scalars]

maybe_int_fixed_size_list_field = [StructField('int_fixed_size_list', ArrayType(IntegerType(), False), False)] \
if not is_list_of_scalar_broken else []

# WARNING: surprisingly, schema fields and row fields are matched only by order and not name.
# We must maintain alphabetical order of the struct fields for the code to work!!!
schema = StructType([
StructField('datetime', DateType(), False),
StructField('float64', DoubleType(), False),
StructField('id', IntegerType(), False),
StructField('int_fixed_size_list', ArrayType(IntegerType(), False), False),
StructField('string', StringType(), False),
StructField('string2', StringType(), False),
StructField('timestamp', TimestampType(), False),
])
schema = StructType(
[
StructField('datetime', DateType(), False),
StructField('float64', DoubleType(), False),
StructField('id', IntegerType(), False),
] + maybe_int_fixed_size_list_field +
[
StructField('string', StringType(), False),
StructField('string2', StringType(), False),
StructField('timestamp', TimestampType(), False),
])

dataframe = spark.createDataFrame(rows, schema)
dataframe. \
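
A minimal sketch (not part of the commit) of the ordering gotcha the WARNING above refers to, assuming Spark 2.x behavior where Row(**kwargs) sorts its fields alphabetically and createDataFrame then matches them to the schema by position:

from pyspark.sql import Row, SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

spark = SparkSession.builder.master('local[1]').getOrCreate()
# Struct fields are listed alphabetically so they line up with the sorted Row fields
schema = StructType([StructField('a', IntegerType(), False), StructField('b', StringType(), False)])
df = spark.createDataFrame([Row(b='hello', a=1)], schema)
assert df.collect()[0]['a'] == 1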
8 changes: 6 additions & 2 deletions petastorm/tests/test_pytorch_dataloader.py
@@ -1,6 +1,7 @@
from decimal import Decimal

import numpy as np
import pyarrow as pa
import pytest
# Must import pyarrow before torch. See: https://github.com/uber/petastorm/blob/master/docs/troubleshoot.rst
import torch
@@ -162,5 +163,8 @@ def test_with_batch_reader(scalar_dataset, shuffling_queue_capacity):
batch_size=3, shuffling_queue_capacity=shuffling_queue_capacity) as loader:
batches = list(loader)
assert len(scalar_dataset.data) == sum(batch['id'].shape[0] for batch in batches)
assert len(scalar_dataset.data) == sum(batch['int_fixed_size_list'].shape[0] for batch in batches)
assert batches[0]['int_fixed_size_list'].shape[1] == len(scalar_dataset.data[0]['int_fixed_size_list'])

# list types are broken in pyarrow 0.15.0. Don't test list-of-int field
if pa.__version__ != '0.15.0':
assert len(scalar_dataset.data) == sum(batch['int_fixed_size_list'].shape[0] for batch in batches)
assert batches[0]['int_fixed_size_list'].shape[1] == len(scalar_dataset.data[0]['int_fixed_size_list'])
4 changes: 3 additions & 1 deletion petastorm/unischema.py
@@ -29,6 +29,8 @@
from pyarrow.lib import StructType as pyStructType
from six import string_types

from petastorm.compat import compat_get_metadata


def _fields_as_tuple(field):
"""Common representation of UnischemaField for equality and hash operators.
@@ -301,7 +303,7 @@ def from_arrow_schema(cls, parquet_dataset, omit_unsupported_fields=False):
:param omit_unsupported_fields: :class:`Boolean`
:return: A :class:`Unischema` object.
"""
meta = parquet_dataset.pieces[0].get_metadata(parquet_dataset.fs.open)
meta = compat_get_metadata(parquet_dataset.pieces[0], parquet_dataset.fs.open)
arrow_schema = meta.schema.to_arrow_schema()
unischema_fields = []

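
For illustration (not part of the commit), fetching the arrow schema through the compat layer could look like the following; the dataset path is made up:

from pyarrow import parquet as pq

from petastorm.compat import compat_get_metadata

dataset = pq.ParquetDataset('/tmp/some_parquet_store')  # hypothetical local dataset path
meta = compat_get_metadata(dataset.pieces[0], dataset.fs.open)
arrow_schema = meta.schema.to_arrow_schema()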
