Skip to content
This repository has been archived by the owner on Jan 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #269 from ndawe/master
Browse files Browse the repository at this point in the history
stretch: if fields is single string (not list) then flatten output
  • Loading branch information
ndawe committed Aug 22, 2016
2 parents d25104f + 6945887 commit 96921a5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
14 changes: 12 additions & 2 deletions root_numpy/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,11 @@ def stretch(arr, fields=None, return_indices=False):
----------
arr : NumPy structured or record array
The array to be stretched.
fields : list of strings, optional (default=None)
A list of column names to stretch. If None, then stretch all fields.
fields : list of strings or single string, optional (default=None)
A list of column names or a single column name to stretch.
If ``fields`` is a string, then the output array is a one-dimensional
unstructured array containing only the stretched elements of that
field. If None, then stretch all fields.
return_indices : bool, optional (default=False)
If True, the array index of each stretched array entry will be
returned in addition to the stretched array.
Expand All @@ -154,8 +157,12 @@ def stretch(arr, fields=None, return_indices=False):
dtype = []
len_array = None

flatten = False
if fields is None:
fields = arr.dtype.names
elif isinstance(fields, string_types):
fields = [fields]
flatten = True

# Construct dtype and check consistency
for field in fields:
Expand Down Expand Up @@ -197,6 +204,9 @@ def stretch(arr, fields=None, return_indices=False):
# Scalar field
ret[field] = np.repeat(arr[field], len_array)

if flatten:
ret = ret[fields[0]]

if return_indices:
idx = np.concatenate(list(map(np.arange, len_array)))
return ret, idx
Expand Down
4 changes: 4 additions & 0 deletions root_numpy/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,10 @@ def test_stretch():
from_stretched = stretched[idx == 0]['vl1']
assert_array_equal(from_arr, from_stretched)

# stretch single field and produce unstructured output
stretched = rnp.stretch(arr, 'vl1')
assert_equal(stretched.dtype, np.int)


def test_blockwise_inner_join():
test_data = np.array([
Expand Down

0 comments on commit 96921a5

Please sign in to comment.