[Inductor] Fallback scatter when src dtype is bf16
basic_gnn_gcn, basic_gnn_gin, basic_gnn_sage now pass

ghstack-source-id: ca07bd6e82a81bc327e0c26216e6841c8a96987b
Pull Request resolved: #113204
oulgen committed Nov 9, 2023
1 parent ee777a7 commit e4ac291
Showing 2 changed files with 24 additions and 2 deletions.
14 changes: 14 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -666,6 +666,20 @@ def fn(a):
 
         self.common(fn, [torch.linspace(-10, 10, 41)])
 
+    def test_scatter_bf16(self):
+        def fn(inp, src, index):
+            return inp.scatter_add(0, index, src)
+
+        for dtype in [torch.int64, torch.bool, torch.bfloat16]:
+            self.common(
+                fn,
+                [
+                    torch.zeros(3, 5, dtype=dtype),
+                    torch.ones((2, 5), dtype=dtype),
+                    torch.tensor([[0, 1, 2, 0, 0]]),
+                ],
+            )
+
     def test_randn_generator(self):
         def fn(a, generator):
             torch.randn([20, 20], generator=generator, device=a.device)
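
Note: the new test drives each dtype through self.common, which runs fn both eagerly and under the Inductor compiler and compares the results. A minimal standalone sketch of the same check, assuming a PyTorch build with torch.compile available (the names below are illustrative, not part of the PR):

import torch

def fn(inp, src, index):
    return inp.scatter_add(0, index, src)

compiled_fn = torch.compile(fn)  # Inductor is the default backend

for dtype in [torch.int64, torch.bool, torch.bfloat16]:
    inp = torch.zeros(3, 5, dtype=dtype)
    src = torch.ones((2, 5), dtype=dtype)
    index = torch.tensor([[0, 1, 2, 0, 0]])
    # With the fallback in place, the compiled result should match eager
    # for all three dtypes that tl.atomic_add cannot handle.
    torch.testing.assert_close(compiled_fn(inp, src, index), fn(inp, src, index))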
12 changes: 10 additions & 2 deletions torch/_inductor/lowering.py
@@ -2926,6 +2926,11 @@ def _unsafe_index_put_(self, indices, values, accumulate=False):
     return index_put_impl_(self, indices, values, accumulate, check=False)
 
 
+def needs_fallback_due_to_atomic_add_limitations(dtype):
+    # tl.atomic_add does NOT support the following types
+    return dtype in {torch.int64, torch.bool, torch.bfloat16}
+
+
 def index_put_impl_(self, indices, values, accumulate, check):
     # Dispatch to masked fill for single boolean index with single value
     if (
@@ -2950,8 +2955,7 @@ def index_put_impl_(self, indices, values, accumulate, check):
     x_size = self.get_size()
     x_ndim = len(x_size)
 
-    # fallback to aten.index_put_, as tl.atomic_add does NOT support int64 or bool
-    if self.get_dtype() in {torch.int64, torch.bool}:
+    if needs_fallback_due_to_atomic_add_limitations(self.get_dtype()):
         # self is an scalar Tensor
         if x_ndim == 0:
             self = view(self, [1])
@@ -3079,6 +3083,10 @@ def scatter_fallback(
     reduce_ty = "add" if fn == "aten.scatter_" else "sum"
     if (
         reduce not in {None, reduce_ty}
+        or (
+            isinstance(src, TensorBox)
+            and needs_fallback_due_to_atomic_add_limitations(src.get_dtype())
+        )
         or (
             fn == "aten.scatter_reduce_"
             and reduce == "sum"
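
Note: the limitation being worked around here lives in Triton, not in Inductor itself — per the comment in the helper above, tl.atomic_add supports none of int64, bool, or bfloat16, so any scatter or index_put that would lower to an atomic-add kernel on those dtypes must fall back to the ATen op. A hypothetical minimal kernel illustrating the constraint (requires a CUDA GPU and the triton package; everything below is an illustrative sketch, not part of this PR). It runs for float32 buffers; switching the tensors to torch.bfloat16 would make Triton reject the tl.atomic_add:

import torch
import triton
import triton.language as tl

@triton.jit
def atomic_add_kernel(out_ptr, src_ptr, n_elements, BLOCK: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    src = tl.load(src_ptr + offsets, mask=mask)
    # Per the lowering comment above, this is the call that has no
    # bf16/int64/bool variant, forcing the fallback for those dtypes.
    tl.atomic_add(out_ptr + offsets, src, mask=mask)

out = torch.zeros(1024, device="cuda", dtype=torch.float32)
src = torch.ones(1024, device="cuda", dtype=torch.float32)
grid = (triton.cdiv(out.numel(), 128),)
atomic_add_kernel[grid](out, src, out.numel(), BLOCK=128)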
