Skip to content

Commit

Permalink
Replace kwargs with factories.
Browse files Browse the repository at this point in the history
As title.
  • Loading branch information
stuartarchibald committed Oct 3, 2018
1 parent d68de2f commit 90acad4
Showing 1 changed file with 54 additions and 46 deletions.
100 changes: 54 additions & 46 deletions numba/targets/arraymath.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,55 +678,63 @@ def nan_aware_less_than(a, b):
else:
return a < b

@register_jitable
def _partition(A, low, high, should_pivot=less_than):
mid = (low + high) >> 1
# NOTE: the pattern of swaps below for the pivot choice and the
# partitioning gives good results (i.e. regular O(n log n))
# on sorted, reverse-sorted, and uniform arrays. Subtle changes
# risk breaking this property.

# Use median of three {low, middle, high} as the pivot
if should_pivot(A[mid], A[low]):
A[low], A[mid] = A[mid], A[low]
if should_pivot(A[high], A[mid]):
A[high], A[mid] = A[mid], A[high]
if should_pivot(A[mid], A[low]):
A[low], A[mid] = A[mid], A[low]
pivot = A[mid]
def _partition_factory(pivotimpl):
def _partition(A, low, high):
mid = (low + high) >> 1
# NOTE: the pattern of swaps below for the pivot choice and the
# partitioning gives good results (i.e. regular O(n log n))
# on sorted, reverse-sorted, and uniform arrays. Subtle changes
# risk breaking this property.

# Use median of three {low, middle, high} as the pivot
if pivotimpl(A[mid], A[low]):
A[low], A[mid] = A[mid], A[low]
if pivotimpl(A[high], A[mid]):
A[high], A[mid] = A[mid], A[high]
if pivotimpl(A[mid], A[low]):
A[low], A[mid] = A[mid], A[low]
pivot = A[mid]

A[high], A[mid] = A[mid], A[high]
i = low
j = high - 1
while True:
while i < high and should_pivot(A[i], pivot):
A[high], A[mid] = A[mid], A[high]
i = low
j = high - 1
while True:
while i < high and pivotimpl(A[i], pivot):
i += 1
while j >= low and pivotimpl(pivot, A[j]):
j -= 1
if i >= j:
break
A[i], A[j] = A[j], A[i]
i += 1
while j >= low and should_pivot(pivot, A[j]):
j -= 1
if i >= j:
break
A[i], A[j] = A[j], A[i]
i += 1
j -= 1
# Put the pivot back in its final place (all items before `i`
# are smaller than the pivot, all items at/after `i` are larger)
A[i], A[high] = A[high], A[i]
return i
# Put the pivot back in its final place (all items before `i`
# are smaller than the pivot, all items at/after `i` are larger)
A[i], A[high] = A[high], A[i]
return i
return _partition

@register_jitable
def _select(arry, k, low, high, should_pivot=less_than):
"""
Select the k'th smallest element in array[low:high + 1].
"""
i = _partition(arry, low, high, should_pivot)
while i != k:
if i < k:
low = i + 1
i = _partition(arry, low, high, should_pivot)
else:
high = i - 1
i = _partition(arry, low, high, should_pivot)
return arry[k]
_partition = register_jitable(_partition_factory(less_than))
_partition_w_nan = register_jitable(_partition_factory(nan_aware_less_than))

def _select_factory(partitionimpl):
def _select(arry, k, low, high):
"""
Select the k'th smallest element in array[low:high + 1].
"""
i = partitionimpl(arry, low, high)
while i != k:
if i < k:
low = i + 1
i = partitionimpl(arry, low, high)
else:
high = i - 1
i = partitionimpl(arry, low, high)
return arry[k]
return _select

_select = register_jitable(_select_factory(_partition))
_select_w_nan = register_jitable(_select_factory(_partition_w_nan))

@register_jitable
def _select_two(arry, k, low, high):
Expand Down Expand Up @@ -939,7 +947,7 @@ def np_partition_impl_inner(a, kth_array):
high = len(dense) - 1

for kth in kth_array:
_select(dense, kth, low, high, should_pivot=nan_aware_less_than)
_select_w_nan(dense, kth, low, high)
low = kth # narrow span of subsequent partition

out[s] = dense
Expand Down

0 comments on commit 90acad4

Please sign in to comment.