Skip to content

Commit

Permalink
MAINT: stats: Work around Cython bug. (#16719)
Browse files Browse the repository at this point in the history
Because of a bug in Cython, the inner function weigh() (a closure
in _weightedrankedtau()) cannot refer to the memoryview arguments
of _weightedrankedtau().  The work-around here is to assign the
arguments to local variables.
  • Loading branch information
WarrenWeckesser committed Jul 28, 2022
1 parent b0a7af7 commit d1e5f1f
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions scipy/stats/_stats.pyx
Expand Up @@ -171,6 +171,11 @@ def _toint64(x):
@cython.wraparound(False)
@cython.boundscheck(False)
def _weightedrankedtau(ordered[:] x, ordered[:] y, intp_t[:] rank, weigher, bool additive):
# y_local and rank_local (declared below) are a work-around for a Cython
# bug; see gh-16718. When we can require Cython 3.0, y_local and
# rank_local can be removed, and the closure weigh() can refer directly
# to y and rank.
cdef ordered[:] y_local = y
cdef intp_t i, first
cdef float64_t t, u, v, w, s, sq
cdef int64_t n = np.int64(len(x))
Expand All @@ -189,6 +194,8 @@ def _weightedrankedtau(ordered[:] x, ordered[:] y, intp_t[:] rank, weigher, bool
rank[...] = perm[::-1]
_invert_in_place(rank)

cdef intp_t[:] rank_local = rank

# weigh joint ties
first = 0
t = 0
Expand Down Expand Up @@ -237,28 +244,28 @@ def _weightedrankedtau(ordered[:] x, ordered[:] y, intp_t[:] rank, weigher, bool
cdef float64_t weight, residual

if length == 1:
return weigher(rank[perm[offset]])
return weigher(rank_local[perm[offset]])
length0 = length // 2
length1 = length - length0
middle = offset + length0
residual = weigh(offset, length0)
weight = weigh(middle, length1) + residual
if y[perm[middle - 1]] < y[perm[middle]]:
if y_local[perm[middle - 1]] < y_local[perm[middle]]:
return weight

# merging
i = j = k = 0

while j < length0 and k < length1:
if y[perm[offset + j]] <= y[perm[middle + k]]:
if y_local[perm[offset + j]] <= y_local[perm[middle + k]]:
temp[i] = perm[offset + j]
residual -= weigher(rank[temp[i]])
residual -= weigher(rank_local[temp[i]])
j += 1
else:
temp[i] = perm[middle + k]
exchanges_weight[0] += weigher(rank[temp[i]]) * (
exchanges_weight[0] += weigher(rank_local[temp[i]]) * (
length0 - j) + residual if additive else weigher(
rank[temp[i]]) * residual
rank_local[temp[i]]) * residual
k += 1
i += 1

Expand Down

0 comments on commit d1e5f1f

Please sign in to comment.