Skip to content
This repository has been archived by the owner on Jan 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #185 from ndawe/stretch
Browse files Browse the repository at this point in the history
[MRG] stretch() improvements
  • Loading branch information
ndawe committed May 24, 2015
2 parents c332a0f + cc2f178 commit 4108861
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 87 deletions.
89 changes: 46 additions & 43 deletions root_numpy/_utils.py
Expand Up @@ -17,10 +17,6 @@
VLEN = np.vectorize(len)


def _is_object_field(arr, col):
return arr.dtype[col] == 'O'


def rec2array(rec, fields=None):
"""Convert a record array into a ndarray with a homogeneous data type.
Expand Down Expand Up @@ -72,7 +68,7 @@ def stack(recs, fields=None):
return np.hstack([rec[fields] for rec in recs])


def stretch(arr, fields):
def stretch(arr, fields=None):
"""Stretch an array.
Stretch an array by ``hstack()``-ing multiple array fields while
Expand All @@ -83,8 +79,8 @@ def stretch(arr, fields):
----------
arr : NumPy structured or record array
The array to be stretched.
fields : list of strings
A list of column names to stretch.
fields : list of strings, optional (default=None)
A list of column names to stretch. If None, then stretch all fields.
Returns
-------
Expand All @@ -103,44 +99,51 @@ def stretch(arr, fields):
dtype=[('scalar', '<i8'), ('array', '<f8')])
"""
dt = []
has_array_field = False
has_scalar_filed = False
first_array = None

# Construct dtype
for c in fields:
if _is_object_field(arr, c):
dt.append((c, arr[c][0].dtype))
has_array_field = True
first_array = c if first_array is None else first_array
else:
# Assume scalar
dt.append((c, arr[c].dtype))
has_scalar_filed = True

if not has_array_field:
raise RuntimeError("No array column specified")

len_array = VLEN(arr[first_array])
numrec = np.sum(len_array)
ret = np.empty(numrec, dtype=dt)

for c in fields:
if _is_object_field(arr, c):
# FIXME: this is rather inefficient since the stack
# is copied over to the return value
stack = np.hstack(arr[c])
if len(stack) != numrec:
dtype = []
len_array = None

if fields is None:
fields = arr.dtype.names

# Construct dtype and check consistency
for field in fields:
dt = arr.dtype[field]
if dt == 'O' or len(dt.shape):
if dt == 'O':
# Variable-length array field
lengths = VLEN(arr[field])
else:
lengths = np.repeat(dt.shape[0], arr.shape[0])
# Fixed-length array field
if len_array is None:
len_array = lengths
elif not np.array_equal(lengths, len_array):
raise ValueError(
"Array lengths do not match: "
"expected %d but found %d in %s" %
(numrec, len(stack), c))
ret[c] = stack
"inconsistent lengths of array columns in input")
if dt == 'O':
dtype.append((field, arr[field][0].dtype))
else:
dtype.append((field, arr[field].dtype, dt.shape[1:]))
else:
# Scalar field
dtype.append((field, dt))

if len_array is None:
raise RuntimeError("no array column in input")

# Build stretched output
ret = np.empty(np.sum(len_array), dtype=dtype)
for field in fields:
dt = arr.dtype[field]
if dt == 'O' or len(dt.shape) == 1:
# Variable-length or 1D fixed-length array field
ret[field] = np.hstack(arr[field])
elif len(dt.shape):
# Multidimensional fixed-length array field
ret[field] = np.vstack(arr[field])
else:
# FIXME: this is rather inefficient since the repeat result
# is copied over to the return value
ret[c] = np.repeat(arr[c], len_array)
# Scalar field
ret[field] = np.repeat(arr[field], len_array)

return ret

Expand Down
85 changes: 41 additions & 44 deletions root_numpy/tests.py
Expand Up @@ -574,56 +574,53 @@ def test_fill_graph():


def test_stretch():
nrec = 5
arr = np.empty(nrec,
arr = np.empty(5,
dtype=[
('scalar', np.int),
('df1', 'O'),
('df2', 'O'),
('df3', 'O')])

for i in range(nrec):
df1 = np.array(range(i + 1), dtype=np.float)
df2 = np.array(range(i + 1), dtype=np.int) * 2
df3 = np.array(range(i + 1), dtype=np.double) * 3
arr[i] = (i, df1, df2, df3)
('vl1', 'O'),
('vl2', 'O'),
('vl3', 'O'),
('fl1', np.int, (2, 2)),
('fl2', np.float, (2, 3)),
('fl3', np.double, (3, 2))])

for i in range(arr.shape[0]):
vl1 = np.array(range(i + 1), dtype=np.int)
vl2 = np.array(range(i + 2), dtype=np.float) * 2
vl3 = np.array(range(2), dtype=np.double) * 3
fl1 = np.array(range(4), dtype=np.int).reshape((2, 2))
fl2 = np.array(range(6), dtype=np.float).reshape((2, 3))
fl3 = np.array(range(6), dtype=np.double).reshape((3, 2))
arr[i] = (i, vl1, vl2, vl3, fl1, fl2, fl3)

# no array columns included
assert_raises(RuntimeError, rnp.stretch, arr, ['scalar',])

stretched = rnp.stretch(
arr, ['scalar', 'df1', 'df2', 'df3'])
# lengths don't match
assert_raises(ValueError, rnp.stretch, arr, ['scalar', 'vl1', 'vl2',])
assert_raises(ValueError, rnp.stretch, arr, ['scalar', 'fl1', 'fl3',])
assert_raises(ValueError, rnp.stretch, arr)

# variable-length stretch
stretched = rnp.stretch(arr, ['scalar', 'vl1',])
assert_equal(stretched.dtype,
[('scalar', np.int),
('df1', np.float),
('df2', np.int),
('df3', np.double)])
assert_equal(stretched.size, 15)

assert_almost_equal(stretched['df1'][14], 4.0)
assert_almost_equal(stretched['df2'][14], 8)
assert_almost_equal(stretched['df3'][14], 12.0)
assert_almost_equal(stretched['scalar'][14], 4)
assert_almost_equal(stretched['scalar'][13], 4)
assert_almost_equal(stretched['scalar'][12], 4)
assert_almost_equal(stretched['scalar'][11], 4)
assert_almost_equal(stretched['scalar'][10], 4)
assert_almost_equal(stretched['scalar'][9], 3)

arr = np.empty(1, dtype=[('scalar', np.int),])
arr[0] = (1,)
assert_raises(RuntimeError, rnp.stretch, arr, ['scalar',])
[('scalar', np.int),
('vl1', np.int)])
assert_equal(stretched.shape[0], 15)
assert_array_equal(
stretched['scalar'],
np.repeat(arr['scalar'], np.vectorize(len)(arr['vl1'])))

nrec = 5
arr = np.empty(nrec,
dtype=[
('scalar', np.int),
('df1', 'O'),
('df2', 'O')])

for i in range(nrec):
df1 = np.array(range(i + 1), dtype=np.float)
df2 = np.array(range(i + 2), dtype=np.int) * 2
arr[i] = (i, df1, df2)
assert_raises(ValueError, rnp.stretch, arr, ['scalar', 'df1', 'df2'])
# fixed-length stretch
stretched = rnp.stretch(arr, ['scalar', 'vl3', 'fl1', 'fl2',])
assert_equal(stretched.dtype,
[('scalar', np.int),
('vl3', np.double),
('fl1', np.int, (2,)),
('fl2', np.float, (3,))])
assert_equal(stretched.shape[0], 10)
assert_array_equal(
stretched['scalar'], np.repeat(arr['scalar'], 2))


def test_blockwise_inner_join():
Expand Down

0 comments on commit 4108861

Please sign in to comment.