[Inductor] Fallback scatter when src dtype is bf16 (#113204)
The basic_gnn_gcn, basic_gnn_gin, and basic_gnn_sage benchmarks now pass.

Pull Request resolved: #113204
Approved by: https://github.com/eellison
oulgen authored and pytorchmergebot committed Nov 9, 2023
1 parent 31ded95 · commit fbf7866
Showing 2 changed files with 24 additions and 2 deletions.
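
For context on the failure mode: Triton's tl.atomic_add cannot accumulate into bfloat16 (nor int64 or bool), so an Inductor-generated scatter kernel cannot handle a bf16 src; the lowering must instead fall back to the eager ATen kernel. A minimal standalone sketch of the affected pattern (illustrative repro, not part of this commit; assumes a build with torch.compile available):

    import torch

    def fn(inp, src, index):
        # Out-of-place scatter_add: accumulate rows of src into inp at index.
        return inp.scatter_add(0, index, src)

    compiled = torch.compile(fn)

    inp = torch.zeros(3, 5, dtype=torch.bfloat16)
    src = torch.ones(2, 5, dtype=torch.bfloat16)
    index = torch.tensor([[0, 1, 2, 0, 0]])

    # With the fallback in place, the compiled result matches eager.
    assert torch.equal(compiled(inp, src, index), fn(inp, src, index))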
test/inductor/test_torchinductor.py (14 additions, 0 deletions)
@@ -667,6 +667,20 @@ def fn(a):
 
         self.common(fn, [torch.linspace(-10, 10, 41)])
 
+    def test_scatter_bf16(self):
+        def fn(inp, src, index):
+            return inp.scatter_add(0, index, src)
+
+        for dtype in [torch.int64, torch.bool, torch.bfloat16]:
+            self.common(
+                fn,
+                [
+                    torch.zeros(3, 5, dtype=dtype),
+                    torch.ones((2, 5), dtype=dtype),
+                    torch.tensor([[0, 1, 2, 0, 0]]),
+                ],
+            )
+
     def test_randn_generator(self):
         def fn(a, generator):
             torch.randn([20, 20], generator=generator, device=a.device)
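
The new test exercises all three fallback dtypes through self.common, which checks the compiled run against eager. It can be run in isolation with, e.g. (assuming a PyTorch source checkout with pytest installed):

    python -m pytest test/inductor/test_torchinductor.py -k test_scatter_bf16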
torch/_inductor/lowering.py (10 additions, 2 deletions)
@@ -2927,6 +2927,11 @@ def _unsafe_index_put_(self, indices, values, accumulate=False):
     return index_put_impl_(self, indices, values, accumulate, check=False)
 
 
+def needs_fallback_due_to_atomic_add_limitations(dtype):
+    # tl.atomic_add does NOT support the following types
+    return dtype in {torch.int64, torch.bool, torch.bfloat16}
+
+
 def index_put_impl_(self, indices, values, accumulate, check):
     # Dispatch to masked fill for single boolean index with single value
     if (
@@ -2951,8 +2956,7 @@ def index_put_impl_(self, indices, values, accumulate, check):
     x_size = self.get_size()
     x_ndim = len(x_size)
 
-    # fallback to aten.index_put_, as tl.atomic_add does NOT support int64 or bool
-    if self.get_dtype() in {torch.int64, torch.bool}:
+    if needs_fallback_due_to_atomic_add_limitations(self.get_dtype()):
         # self is a scalar Tensor
         if x_ndim == 0:
             self = view(self, [1])
@@ -3080,6 +3084,10 @@ def scatter_fallback(
     reduce_ty = "add" if fn == "aten.scatter_" else "sum"
     if (
         reduce not in {None, reduce_ty}
+        or (
+            isinstance(src, TensorBox)
+            and needs_fallback_due_to_atomic_add_limitations(src.get_dtype())
+        )
         or (
             fn == "aten.scatter_reduce_"
            and reduce == "sum"
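
Net effect of the lowering change: scatter-family ops route to the ATen fallback whenever the src dtype is one that tl.atomic_add cannot accumulate. A condensed paraphrase of the predicate above (illustrative only; takes_aten_fallback is not a real helper in the codebase):

    import torch

    # Dtypes tl.atomic_add cannot handle, mirroring
    # needs_fallback_due_to_atomic_add_limitations above.
    ATOMIC_ADD_FALLBACK_DTYPES = {torch.int64, torch.bool, torch.bfloat16}

    def takes_aten_fallback(src_dtype, reduce=None, reduce_ty="add"):
        # Fall back for unsupported reduce modes, or for src dtypes that
        # the generated Triton kernel could not accumulate atomically.
        return reduce not in {None, reduce_ty} or src_dtype in ATOMIC_ADD_FALLBACK_DTYPES

    assert takes_aten_fallback(torch.bfloat16)     # newly falls back (this commit)
    assert takes_aten_fallback(torch.int64)        # already fell back
    assert not takes_aten_fallback(torch.float32)  # keeps the compiled path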
