@@ -1177,11 +1177,13 @@ def group_quantile(
11771177 ndarray[float64_t , ndim = 2 ] out,
11781178 ndarray[numeric_t , ndim = 1 ] values,
11791179 ndarray[intp_t] labels ,
1180- ndarray[uint8_t] mask ,
1181- const intp_t[:] sort_indexer ,
1180+ const uint8_t[:] mask ,
11821181 const float64_t[:] qs ,
1182+ ndarray[int64_t] starts ,
1183+ ndarray[int64_t] ends ,
11831184 str interpolation ,
1184- uint8_t[:, ::1] result_mask = None ,
1185+ uint8_t[:, ::1] result_mask ,
1186+ bint is_datetimelike ,
11851187) -> None:
11861188 """
11871189 Calculate the quantile per group.
@@ -1194,27 +1196,38 @@ def group_quantile(
11941196 Array containing the values to apply the function against.
11951197 labels : ndarray[np.intp]
11961198 Array containing the unique group labels.
1197- sort_indexer : ndarray[np.intp]
1198- Indices describing sort order by values and labels.
11991199 qs : ndarray[float64_t]
12001200 The quantile values to search for.
1201+ starts : ndarray[int64]
1202+ Positions at which each group begins.
1203+ ends : ndarray[int64]
1204+ Positions at which each group ends.
12011205 interpolation : {'linear', 'lower', 'highest', 'nearest', 'midpoint'}
1206+ result_mask : ndarray[bool , ndim = 2 ] or None
1207+ is_datetimelike : bool
1208+ Whether int64 values represent datetime64-like values.
12021209
12031210 Notes
12041211 -----
12051212 Rather than explicitly returning a value , this function modifies the
12061213 provided `out` parameter.
12071214 """
12081215 cdef:
1209- Py_ssize_t i , N = len (labels), ngroups , grp_sz , non_na_sz , k , nqs
1210- Py_ssize_t grp_start = 0 , idx = 0
1211- intp_t lab
1216+ Py_ssize_t i , N = len (labels), ngroups , non_na_sz , k , nqs
1217+ Py_ssize_t idx = 0
1218+ Py_ssize_t grp_size
12121219 InterpolationEnumType interp
12131220 float64_t q_val , q_idx , frac , val , next_val
1214- int64_t[::1] counts , non_na_counts
12151221 bint uses_result_mask = result_mask is not None
1222+ Py_ssize_t start , end
1223+ ndarray[numeric_t] grp
1224+ intp_t[::1] sort_indexer
1225+ const uint8_t[:] sub_mask
12161226
12171227 assert values.shape[0] == N
1228+ assert starts is not None
1229+ assert ends is not None
1230+ assert len(starts ) == len(ends )
12181231
12191232 if any(not (0 <= q <= 1) for q in qs ):
12201233 wrong = [x for x in qs if not (0 <= x <= 1 )][0 ]
@@ -1233,64 +1246,65 @@ def group_quantile(
12331246
12341247 nqs = len (qs)
12351248 ngroups = len (out)
1236- counts = np.zeros(ngroups, dtype = np.int64)
1237- non_na_counts = np.zeros(ngroups, dtype = np.int64)
1238-
1239- # First figure out the size of every group
1240- with nogil:
1241- for i in range (N):
1242- lab = labels[i]
1243- if lab == - 1 : # NA group label
1244- continue
12451249
1246- counts[lab] += 1
1247- if not mask[i]:
1248- non_na_counts[lab] += 1
1250+ # TODO: get cnp.PyArray_ArgSort to work with nogil so we can restore the rest
1251+ # of this function as being `with nogil:`
1252+ for i in range (ngroups):
1253+ start = starts[i]
1254+ end = ends[i]
1255+
1256+ grp = values[start:end]
1257+
1258+ # Figure out how many group elements there are
1259+ sub_mask = mask[start:end]
1260+ grp_size = sub_mask.size
1261+ non_na_sz = 0
1262+ for k in range (grp_size):
1263+ if sub_mask[k] == 0 :
1264+ non_na_sz += 1
1265+
1266+ # equiv: sort_indexer = grp.argsort()
1267+ if is_datetimelike:
1268+ # We need the argsort to put NaTs at the end, not the beginning
1269+ sort_indexer = cnp.PyArray_ArgSort(grp.view(" M8[ns]" ), 0 , cnp.NPY_QUICKSORT)
1270+ else :
1271+ sort_indexer = cnp.PyArray_ArgSort(grp, 0 , cnp.NPY_QUICKSORT)
12491272
1250- with nogil:
1251- for i in range (ngroups):
1252- # Figure out how many group elements there are
1253- grp_sz = counts[i]
1254- non_na_sz = non_na_counts[i]
1255-
1256- if non_na_sz == 0 :
1257- for k in range (nqs):
1258- if uses_result_mask:
1259- result_mask[i, k] = 1
1260- else :
1261- out[i, k] = NaN
1262- else :
1263- for k in range (nqs):
1264- q_val = qs[k]
1273+ if non_na_sz == 0 :
1274+ for k in range (nqs):
1275+ if uses_result_mask:
1276+ result_mask[i, k] = 1
1277+ else :
1278+ out[i, k] = NaN
1279+ else :
1280+ for k in range (nqs):
1281+ q_val = qs[k]
12651282
1266- # Calculate where to retrieve the desired value
1267- # Casting to int will intentionally truncate result
1268- idx = grp_start + < int64_t> (q_val * < float64_t> (non_na_sz - 1 ))
1283+ # Calculate where to retrieve the desired value
1284+ # Casting to int will intentionally truncate result
1285+ idx = < int64_t> (q_val * < float64_t> (non_na_sz - 1 ))
12691286
1270- val = values [sort_indexer[idx]]
1271- # If requested quantile falls evenly on a particular index
1272- # then write that index's value out. Otherwise interpolate
1273- q_idx = q_val * (non_na_sz - 1 )
1274- frac = q_idx % 1
1287+ val = grp [sort_indexer[idx]]
1288+ # If requested quantile falls evenly on a particular index
1289+ # then write that index's value out. Otherwise interpolate
1290+ q_idx = q_val * (non_na_sz - 1 )
1291+ frac = q_idx % 1
12751292
1276- if frac == 0.0 or interp == INTERPOLATION_LOWER:
1277- out[i, k] = val
1278- else :
1279- next_val = values[sort_indexer[idx + 1 ]]
1280- if interp == INTERPOLATION_LINEAR:
1281- out[i, k] = val + (next_val - val) * frac
1282- elif interp == INTERPOLATION_HIGHER:
1293+ if frac == 0.0 or interp == INTERPOLATION_LOWER:
1294+ out[i, k] = val
1295+ else :
1296+ next_val = grp[sort_indexer[idx + 1 ]]
1297+ if interp == INTERPOLATION_LINEAR:
1298+ out[i, k] = val + (next_val - val) * frac
1299+ elif interp == INTERPOLATION_HIGHER:
1300+ out[i, k] = next_val
1301+ elif interp == INTERPOLATION_MIDPOINT:
1302+ out[i, k] = (val + next_val) / 2.0
1303+ elif interp == INTERPOLATION_NEAREST:
1304+ if frac > .5 or (frac == .5 and q_val > .5 ): # Always OK?
12831305 out[i, k] = next_val
1284- elif interp == INTERPOLATION_MIDPOINT:
1285- out[i, k] = (val + next_val) / 2.0
1286- elif interp == INTERPOLATION_NEAREST:
1287- if frac > .5 or (frac == .5 and q_val > .5 ): # Always OK?
1288- out[i, k] = next_val
1289- else :
1290- out[i, k] = val
1291-
1292- # Increment the index reference in sorted_arr for the next group
1293- grp_start += grp_sz
1306+ else :
1307+ out[i, k] = val
12941308
12951309
12961310# ----------------------------------------------------------------------
0 commit comments