Skip to content

Commit

Permalink
Revert "[FIX] intp -> uintp for cupy"
Browse files Browse the repository at this point in the history
This reverts commit 59ed451.
  • Loading branch information
dcherian committed May 12, 2023
1 parent b6a7edc commit 5df4a75
Showing 1 changed file with 19 additions and 21 deletions.
40 changes: 19 additions & 21 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,8 @@ def __repr__(self) -> str:
combine="sum",
fill_value=0,
final_fill_value=0,
dtypes=np.uintp,
final_dtype=np.uintp,
dtypes=np.intp,
final_dtype=np.intp,
)

# note that the fill values are the result of np.func([np.nan, np.nan])
Expand Down Expand Up @@ -281,7 +281,7 @@ def _mean_finalize(sum_, count):
combine=("sum", "sum"),
finalize=_mean_finalize,
fill_value=(0, 0),
dtypes=(None, np.uintp),
dtypes=(None, np.intp),
final_dtype=np.floating,
)
nanmean = Aggregation(
Expand All @@ -290,7 +290,7 @@ def _mean_finalize(sum_, count):
combine=("sum", "sum"),
finalize=_mean_finalize,
fill_value=(0, 0),
dtypes=(None, np.uintp),
dtypes=(None, np.intp),
final_dtype=np.floating,
)

Expand All @@ -315,7 +315,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
finalize=_var_finalize,
fill_value=0,
final_fill_value=np.nan,
dtypes=(None, None, np.uintp),
dtypes=(None, None, np.intp),
final_dtype=np.floating,
)
nanvar = Aggregation(
Expand All @@ -325,7 +325,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
finalize=_var_finalize,
fill_value=0,
final_fill_value=np.nan,
dtypes=(None, None, np.uintp),
dtypes=(None, None, np.intp),
final_dtype=np.floating,
)
std = Aggregation(
Expand All @@ -335,7 +335,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
finalize=_std_finalize,
fill_value=0,
final_fill_value=np.nan,
dtypes=(None, None, np.uintp),
dtypes=(None, None, np.intp),
final_dtype=np.floating,
)
nanstd = Aggregation(
Expand All @@ -345,7 +345,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
finalize=_std_finalize,
fill_value=0,
final_fill_value=np.nan,
dtypes=(None, None, np.uintp),
dtypes=(None, None, np.intp),
final_dtype=np.floating,
)

Expand All @@ -368,7 +368,7 @@ def argreduce_preprocess(array, axis):
assert len(axis) == 1
axis = axis[0]

idx = dask.array.arange(array.shape[axis], chunks=array.chunks[axis], dtype=np.uintp)
idx = dask.array.arange(array.shape[axis], chunks=array.chunks[axis], dtype=np.intp)
# broadcast (TODO: is this needed?)
idx = idx[tuple(slice(None) if i == axis else np.newaxis for i in range(array.ndim))]

Expand Down Expand Up @@ -398,8 +398,8 @@ def _pick_second(*x):
fill_value=(dtypes.NINF, 0),
final_fill_value=-1,
finalize=_pick_second,
dtypes=(None, np.uintp),
final_dtype=np.uintp,
dtypes=(None, np.intp),
final_dtype=np.intp,
)

argmin = Aggregation(
Expand All @@ -411,8 +411,8 @@ def _pick_second(*x):
fill_value=(dtypes.INF, 0),
final_fill_value=-1,
finalize=_pick_second,
dtypes=(None, np.uintp),
final_dtype=np.uintp,
dtypes=(None, np.intp),
final_dtype=np.intp,
)

nanargmax = Aggregation(
Expand All @@ -424,8 +424,8 @@ def _pick_second(*x):
fill_value=(dtypes.NINF, 0),
final_fill_value=-1,
finalize=_pick_second,
dtypes=(None, np.uintp),
final_dtype=np.uintp,
dtypes=(None, np.intp),
final_dtype=np.intp,
)

nanargmin = Aggregation(
Expand All @@ -437,8 +437,8 @@ def _pick_second(*x):
fill_value=(dtypes.INF, 0),
final_fill_value=-1,
finalize=_pick_second,
dtypes=(None, np.uintp),
final_dtype=np.uintp,
dtypes=(None, np.intp),
final_dtype=np.intp,
)

first = Aggregation("first", chunk=None, combine=None, fill_value=0)
Expand Down Expand Up @@ -574,10 +574,8 @@ def _initialize_aggregation(
agg.combine += ("sum",)
agg.fill_value["intermediate"] += (0,)
agg.fill_value["numpy"] += (0,)
# uintp is supported by cupy, intp is not
# Also count is >=0, so uint should be fine.
agg.dtype["intermediate"] += (np.uintp,)
agg.dtype["numpy"] += (np.uintp,)
agg.dtype["intermediate"] += (np.intp,)
agg.dtype["numpy"] += (np.intp,)
else:
agg.min_count = 0

Expand Down

0 comments on commit 5df4a75

Please sign in to comment.