Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6542,6 +6542,45 @@ def fn(x):

self.common(fn, (torch.randn((16, 16, 16)),), check_lowp=False)

def test_inductor_bucketize(self):
    """
    Compare the compiled ``_inductor_bucketize`` prim against eager mode on
    float inputs, for every combination of ``out_int32`` and ``right``.
    """

    def fn(input, boundaries, out_int32, right):
        return torch.ops.prims._inductor_bucketize(
            input, boundaries, out_int32=out_int32, right=right
        )

    # Uniform samples from [-1, 1): the range extends past both the smallest
    # and the largest boundary.
    input = torch.rand((64, 64)) * 2 - 1
    boundaries = torch.tensor([-0.9, -0.8, 0.1, 0.2, 0.5, 0.9])

    # Pass the loop variables through unchanged so that all four flag
    # combinations are actually exercised.
    for out_int32 in [True, False]:
        for right in [True, False]:
            self.common(fn, (input, boundaries, out_int32, right), check_lowp=False)

def test_inductor_bucketize_default_kwargs(self):
    """
    Run the ``_inductor_bucketize`` prim without passing ``out_int32`` or
    ``right``, so that the prim's own defaults are used.
    """
    # Hand-picked values: several coincide exactly with an offset, the rest
    # fall below, between, or above the offsets.
    values = torch.tensor(
        [-1.0, -0.9, -0.8, -0.5, 0.0, 0.1, 0.2, 0.4, 0.5, 0.6, 0.9, 0.91]
    )
    edges = torch.tensor([-0.9, -0.8, 0.1, 0.2, 0.5, 0.9])

    def fn(input, offsets):
        return torch.ops.prims._inductor_bucketize(input, offsets)

    self.common(fn, (values, edges), check_lowp=False)

def test_inductor_bucketize_int(self):
    """
    Compare the compiled ``_inductor_bucketize`` prim against eager mode on
    integer inputs and int32 offsets, for every combination of ``out_int32``
    and ``right``.
    """

    def fn(input, offsets, out_int32, right):
        return torch.ops.prims._inductor_bucketize(
            input, offsets, out_int32=out_int32, right=right
        )

    samples = torch.randint(0, 102, (64, 64))
    # Offsets are k**2 + 1 for k in 0..9, i.e. 1, 2, 5, 10, ..., 82.
    edges = torch.arange(10, dtype=torch.int32) ** 2 + 1

    flag_combinations = [
        (use_int32, closed_left)
        for use_int32 in [True, False]
        for closed_left in [True, False]
    ]
    for out_int32, right in flag_combinations:
        self.common(fn, (samples, edges, out_int32, right), check_lowp=False)


@dataclasses.dataclass
class TestFailure:
Expand Down
3 changes: 3 additions & 0 deletions test/inductor/test_torchinductor_codegen_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ def run(*ex, **kwargs):
"test_empty2_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_empty_strided_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_index3_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_inductor_bucketize_dynamic_shapes": TestFailure(("cpu")),
"test_inductor_bucketize_default_kwargs_dynamic_shapes": TestFailure(("cpu")),
"test_inductor_bucketize_int_dynamic_shapes": TestFailure(("cpu")),
"test_like_rands_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_linspace2_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_linspace3_dynamic_shapes": TestFailure(("cpu", "cuda")),
Expand Down
39 changes: 39 additions & 0 deletions torch/_inductor/codegen/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,19 @@ def store(self, name, index, value, mode=None):
def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
    """Reduction hook; this implementation always raises NotImplementedError."""
    raise NotImplementedError

def bucketize(
    self,
    values,
    offsets_name: str,
    offsets_size,
    indexing_dtype: torch.dtype,
    right: bool,
):
    """
    Bucketize hook; this implementation always raises NotImplementedError.

    See [Note: Inductor bucketize op]
    """
    raise NotImplementedError

def __enter__(self):
class CSEProxy:
self.name = "CSEProxy"
Expand Down Expand Up @@ -833,6 +846,32 @@ def reduction(name, dtype, src_dtype, reduction_type, index, value):
name, dtype, src_dtype, reduction_type, index, value
)

@staticmethod
def bucketize(
    values,
    offsets_name: str,
    offsets_size,
    indexing_dtype: torch.dtype,
    right: bool,
):
    """
    [Note: Inductor bucketize op]

    Given values (tensor) and offsets_name (reference to the name of a 1D
    tensor), calculate the bucket that each value belongs to.

    e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
    return =        [ 0, 1, 1, 1, 1, 3, 3, 4].

    With N offsets there are N + 1 buckets, numbered 0..N. Bucket 0 is
    unbounded below and bucket N is unbounded above; for the others:

    When right == False, bucket i refers to range (offsets[i-1], offsets[i]].
    When right == True,  bucket i refers to range [offsets[i-1], offsets[i]).

    Offsets must be non-decreasing or the result is undefined.

    Args:
        values: the values to bucketize.
        offsets_name: name of the 1D offsets tensor.
        offsets_size: number of entries in the offsets tensor.
        indexing_dtype: dtype of the returned bucket indices.
        right: selects which side of each bucket is closed (see above).

    Returns:
        Whatever the enclosing kernel's bucketize() returns; the arguments
        are forwarded unchanged.
    """
    # `self` is not a parameter of this static method; it is the name bound
    # in the enclosing scope.
    return self.bucketize(
        values, offsets_name, offsets_size, indexing_dtype, right
    )

super().__enter__()
parent_handler = self.overrides(V.get_ops_handler())
self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
Expand Down
31 changes: 31 additions & 0 deletions torch/_inductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -1292,6 +1292,37 @@ def store(self, name, index, value, mode=None):
if not self.inside_reduction:
self.outside_loop_vars.add(value)

def bucketize(
    self,
    values: CSEVariable,
    offsets_name: str,
    offsets_size,
    indexing_dtype: torch.dtype,
    right: bool,
):
    """
    Generate a ``triton_helpers.bucketize_binary_search(...)`` call in the
    compute buffer and return the value produced by ``self.cse.generate``.

    Only ``torch.int32`` and ``torch.int64`` are accepted for
    ``indexing_dtype``; anything else raises NotImplementedError.

    See [Note: Inductor bucketize op]
    """
    # Both lookups happen before the dtype check, as in the if/elif form.
    offsets_ptr = self.args.input(offsets_name)
    block_size = self.dense_size_str()

    # Spelling of each supported index dtype inside the generated source.
    triton_dtype = {
        torch.int32: "tl.int32",
        torch.int64: "tl.int64",
    }.get(indexing_dtype)
    if triton_dtype is None:
        raise NotImplementedError(
            "Bucketize only supports indexing with int32 and int64"
        )

    return self.cse.generate(
        self.compute,
        f"triton_helpers.bucketize_binary_search({values}, {offsets_ptr}, {triton_dtype}, {right}, {offsets_size}, {block_size})",  # noqa: B950 line too long
    )

def reduction_resize(self, value):
ndims = self.triton_tensor_ndim()
if ndims == 1:
Expand Down
22 changes: 22 additions & 0 deletions torch/_inductor/inductor_prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch
from torch import _prims
from torch._prims_common import RETURN_TYPE

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -71,3 +72,24 @@ def eager_force_stride(input_tensor, stride):
lambda input_tensor, stride: eager_force_stride(input_tensor, stride),
doc="Force the stride order for input tensor. No-op if the input tensor already has the stride. Do a copy otherwise",
)


def _inductor_bucketize_impl(input, boundaries, *, out_int32=False, right=False):
    """Eager implementation of the prim: delegates directly to torch.bucketize."""
    bucket_indices = torch.bucketize(
        input,
        boundaries,
        right=right,
        out_int32=out_int32,
    )
    return bucket_indices


def _inductor_bucketize_meta(input, boundaries, *, out_int32=False, right=False):
    """
    Meta function for the prim: return an uninitialized tensor shaped like
    ``input`` whose dtype is int32 when ``out_int32`` is set and int64
    otherwise. ``boundaries`` and ``right`` do not affect the result.
    """
    result_dtype = torch.int64
    if out_int32:
        result_dtype = torch.int32
    return torch.empty_like(
        input,
        dtype=result_dtype,
        memory_format=torch.preserve_format,
    )


# Register the bucketize prim. Its eager implementation forwards to
# torch.bucketize and its meta function only allocates the output.
_bucketize = _prims._make_prim(
    schema="_inductor_bucketize(Tensor input, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=_inductor_bucketize_meta,
    impl_aten=_inductor_bucketize_impl,
    doc="Same as torch.bucketize(), but does not get decomposed.",
)
43 changes: 43 additions & 0 deletions torch/_inductor/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .ir import (
ExpandView,
IndexingConstant,
is_triton,
ops_wrapper,
PermuteView,
Pointwise,
Expand Down Expand Up @@ -1556,6 +1557,48 @@ def inner_fn(index):
)


@register_lowering(inductor_prims._bucketize, type_promotion_kind=None)
def _inductor_bucketize(
    input: TensorBox,
    boundaries: TensorBox,
    *,
    out_int32: bool = False,
    right: bool = False,
):
    """
    Lower the ``_inductor_bucketize`` prim.

    When both ``input`` and ``boundaries`` satisfy ``is_triton``, build a
    Pointwise node whose inner function calls ``ops.bucketize`` on each loaded
    input value; otherwise defer to the fallback handler for the prim.

    Args:
        input: values to bucketize; the result has the same size.
        boundaries: 1D tensor of bucket boundaries.
        out_int32: produce int32 indices instead of int64.
        right: forwarded to ``ops.bucketize``; see [Note: Inductor bucketize op].

    Returns:
        The fallback handler's result, or a Pointwise node with dtype
        int32/int64 and the same ranges as ``input``.
    """
    assert len(boundaries.get_size()) == 1

    if not (is_triton(input) and is_triton(boundaries)):
        return fallback_handler(inductor_prims._bucketize, add_to_fallback_set=False)(
            input, boundaries, out_int32=out_int32, right=right
        )

    # ops.bucketize is handed the boundaries by buffer name rather than
    # through a loader, so the whole boundaries tensor must exist as a named
    # buffer. Realize it first; otherwise boundaries.get_name() below is not
    # guaranteed to resolve.
    boundaries.realize()
    boundaries_size = boundaries.get_size()[0]
    device = input.get_device()
    input_loader = input.make_loader()

    index_dtype = torch.int32 if out_int32 else torch.int64

    def inner_fn(index):
        val = input_loader(index)
        return ops.bucketize(
            val,
            boundaries.get_name(),
            ops.index_expr(boundaries_size, index_dtype),
            index_dtype,
            right,
        )

    return Pointwise.create(
        device=device,
        dtype=index_dtype,
        inner_fn=inner_fn,
        ranges=input.get_size(),
    )


def require_dense(_, *args, **kwargs):
args, kwargs = pytree.tree_map_only(
ir.IRNode, lambda t: ir.ExternKernel.require_stride1(t), (args, kwargs)
Expand Down
34 changes: 34 additions & 0 deletions torch/_inductor/triton_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,37 @@ def _any_combine(a, b):
@triton.jit
def any(a, dim):
return tl.reduce(a, dim, _any_combine)


@triton.jit
def bucketize_binary_search(
    values,  # 1D tensor
    offsets_ptr,
    indexing_dtype,
    right,  # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]
    OFFSETS_SIZE: int,
    BLOCK_SHAPE,  # tuple/list of block shape
):
    """
    See [Note: Inductor bucketize op]

    Binary-search the OFFSETS_SIZE entries starting at offsets_ptr for all
    elements of `values` at once, and return a BLOCK_SHAPE block of
    indexing_dtype indices, each in the range [0, OFFSETS_SIZE].

    With right == False the comparison is `values > offset`; with
    right == True it is `values >= offset`.
    """

    # Per-element search bounds: the candidate answers are [low, high].
    low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)
    high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)

    # full_range only counts iterations: it starts at the number of possible
    # answers (OFFSETS_SIZE + 1) and is halved, rounding up, each pass. The
    # trip count therefore depends on OFFSETS_SIZE alone, not on the data.
    full_range = OFFSETS_SIZE + 1
    while full_range > 1:
        mid = (high + low) // 2
        # mid may equal OFFSETS_SIZE (one past the last offset); the mask
        # keeps the load in bounds for those lanes.
        mask = mid < OFFSETS_SIZE
        # No `other=` is given, so the value loaded for masked-out lanes
        # should not be relied on; `is_above & mask` below ensures such
        # lanes never advance `low`.
        bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask)
        if right:
            is_above = values >= bucket_upper_bound
        else:
            is_above = values > bucket_upper_bound

        # Value lies above offsets[mid]: continue in (mid, high].
        # Otherwise: continue in [low, mid].
        low = tl.where(is_above & mask, mid + 1, low)
        high = tl.where(is_above, high, mid)

        full_range = (full_range + 1) // 2

    return low