test/inductor/test_cuda_repro.py (36 additions, 0 deletions)
@@ -1609,6 +1609,42 @@ def foo(inp):
         foo_c = torch.compile(foo)
         torch.testing.assert_allclose(foo(inp), foo_c(inp))
 
+    @skipCUDAIf(
+        not SM90OrLater, "uses bfloat16 atomic add instrs which requires SM >= 90"
+    )
+    def test_float8_e8m0fnu(self):
+        device = "cuda"
+        dtype = torch.float8_e8m0fnu
+        hp_dtype = torch.float32  # and torch.bfloat16
+
+        def foo(x0):
+            x1 = x0.to(dtype)
+            x2 = x1.to(hp_dtype)
+            return x2
+
+        x0 = torch.randn(16, 16, device=device, dtype=hp_dtype)
+        foo_c = torch.compile(foo, backend="inductor", fullgraph=True)
+
+        with torch.no_grad():
+            y_c = foo_c(x0)
+
+        self.assertEqual(foo(x0), y_c)
+
+        dtype = torch.float8_e8m0fnu
+
+        def foo(x0):
+            x1 = x0 + 1
+            x2 = x1.view(dtype)
+            return x2
+
+        x0 = torch.randint(0, 255, (16, 16), device=device, dtype=torch.uint8)
+        foo_c = torch.compile(foo, backend="inductor", fullgraph=True)
+
+        with torch.no_grad():
+            y_c = foo_c(x0)
+
+        self.assertEqual(foo(x0), y_c)
+
     @unittest.skipIf(
         not config.is_fbcode(),
         "bfloat16 atomic add is only supported in fbcode today #97016",
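
For readers unfamiliar with the dtype, here is a minimal eager-mode sketch of the two patterns the new test compiles. It is not taken from the PR; it assumes a CUDA device and a PyTorch build that exposes torch.float8_e8m0fnu, and the values are purely illustrative.

    import torch

    # Pattern 1: round-trip cast. float8_e8m0fnu stores only an 8-bit biased
    # exponent, so each element comes back as a power of two.
    x = torch.tensor([0.5, 1.0, 4.0], device="cuda")
    round_trip = x.to(torch.float8_e8m0fnu).to(torch.float32)

    # Pattern 2: bitcast. Each uint8 byte is reinterpreted as an e8m0fnu value;
    # per the OCP MX spec, byte b decodes to 2**(b - 127), with 0xff reserved for NaN.
    b = torch.tensor([126, 127, 128], device="cuda", dtype=torch.uint8)
    scales = b.view(torch.float8_e8m0fnu)

    print(round_trip)                # tensor([0.5, 1., 4.], ...)
    print(scales.to(torch.float32))  # tensor([0.5, 1., 2.], ...)
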
torch/_inductor/lowering.py (21 additions, 4 deletions)
@@ -72,6 +72,7 @@
     is_dynamic,
     is_gpu,
     is_pointwise_use,
+    is_view,
     needs_fallback_due_to_atomic_add_limitations,
     pad_listlike,
     register_op_dtype_propagation_rules,
@@ -1914,7 +1915,7 @@ def _warn_complex_not_supported():
 
 # There are some types (CPU) which we accept as input but not as
 # output.
-def unsupported_input_tensor(t: torch.Tensor, parent=None):
+def unsupported_input_tensor(t: torch.Tensor, parent=None, node=None):
     "Do not support reading or writing to this tensor"
     if t.is_complex():
         # Complex views are supported with IR ComplexView
@@ -1925,10 +1926,26 @@ def unsupported_input_tensor(t: torch.Tensor, parent=None):
             return False
         _warn_complex_not_supported()
         return True
+
+    if t.dtype == torch.float8_e8m0fnu:
+        if not node:
+            return True
+
+        # allow bitcast, views, memory movement, but not arithmetic
+        # TODO: delete once triton adds native support
+        return not (
+            node.target
+            in (
+                aten.view.dtype,
+                aten.cat.default,
+            )
+            or is_view(node.target)
+        )
+
     return False
 
 
-def unsupported_output_tensor(t: torch.Tensor, parent=None):
+def unsupported_output_tensor(t: torch.Tensor, parent=None, node=None):
     "Do not support writing tensor but can read from it"
     if unsupported_input_tensor(t, parent):
         return True
@@ -1956,10 +1973,10 @@ def check_skip_condition(node, parent, is_output):
                 continue
 
             if is_output:
-                if unsupported_output_tensor(meta, parent):
+                if unsupported_output_tensor(meta, parent, node):
                     return True
             else:
-                if unsupported_input_tensor(meta, parent):
+                if unsupported_input_tensor(meta, parent, node):
                     return True
 
         return False
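
The effect of the node-aware check is that e8m0fnu tensors only stay inside generated code for data movement (aten.view.dtype, aten.cat.default, and other views); anything else is routed to Inductor's fallback (eager) kernels until Triton gains native e8m0 support. Below is a hedged sketch of the intended usage shape, not taken from the PR: only the cast and bitcast patterns are covered by the test above, and the function name, shapes, and cat usage here are made up for illustration.

    import torch

    @torch.compile(backend="inductor", fullgraph=True)
    def combine_scales(raw_a, raw_b):
        a = raw_a.view(torch.float8_e8m0fnu)  # aten.view.dtype: allow-listed
        b = raw_b.view(torch.float8_e8m0fnu)
        return torch.cat([a, b])              # aten.cat.default: allow-listed

    raw_a = torch.randint(0, 255, (8,), dtype=torch.uint8, device="cuda")
    raw_b = torch.randint(0, 255, (8,), dtype=torch.uint8, device="cuda")
    out = combine_scales(raw_a, raw_b)  # pure data movement, no e8m0 arithmetic
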
torch/_inductor/utils.py (6 additions, 0 deletions)
@@ -305,6 +305,9 @@ def _type_of(key: Optional[torch.dtype]) -> str:
         "float8e4b15x4": "fp8e4b15x4",
         "float8_e4m3fn": "fp8e4nv",
         "float8_e5m2": "fp8e5",
+        # TODO: remove when support is added in triton
+        # https://github.com/triton-lang/triton/issues/6054
+        "float8_e8m0fnu": "u8",
         "float16": "fp16",
         "bfloat16": "bf16",
         "float32": "fp32",
@@ -2458,6 +2461,9 @@ def normalize_name(name: str) -> str:
     "tl.float8_e5m2": "tl.float8e5",
     "tl.float8_e4m3fnuz": "tl.float8e4b8",
     "tl.float8_e5m2fnuz": "tl.float8e5b16",
+    # TODO: remove when support is added in triton
+    # https://github.com/triton-lang/triton/issues/6054
+    "tl.float8_e8m0fnu": "tl.uint8",
 }
 _torch_triton_mapping = {v: k for k, v in _triton_type_mapping.items()}
 
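
Mapping float8_e8m0fnu to Triton's u8 is lossless for data movement because an e8m0fnu element is nothing but its 8-bit biased exponent. A quick sketch of the encoding, assuming the OCP MX definition (value = 2**(byte - 127), 0xff reserved for NaN) and a CUDA device with eager e8m0 casts available:

    import torch

    b = torch.arange(120, 135, dtype=torch.uint8, device="cuda")

    # Decode the raw bytes by hand ...
    manual = torch.pow(2.0, b.to(torch.float32) - 127)

    # ... and through the dtype itself. The byte alone carries all the
    # information, which is why inductor can hand Triton plain u8 values.
    via_dtype = b.view(torch.float8_e8m0fnu).to(torch.float32)

    assert torch.equal(manual, via_dtype)
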
torch/fx/graph.py (1 addition, 0 deletions)
@@ -224,6 +224,7 @@ def _rename_object(self, obj: Any, name: str):
     torch.float8_e5m2: "f8e5m2",
     torch.float8_e4m3fnuz: "f8e4m3fnuz",
     torch.float8_e5m2fnuz: "f8e5m2fnuz",
+    torch.float8_e8m0fnu: "f8e8m0fnu",
     torch.complex32: "c32",
     torch.complex64: "c64",
     torch.complex128: "c128",
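
The torch/fx/graph.py entry feeds FX's readable graph printing, where tensor metadata is rendered with short dtype names, so e8m0 tensors now show up as f8e8m0fnu. A quick check of the table, assuming dtype_abbrs keeps its current module-level name in torch/fx/graph.py:

    import torch
    from torch.fx.graph import dtype_abbrs

    # Annotations such as "f8e8m0fnu[16, 16]" in printed graphs come from this table.
    print(dtype_abbrs[torch.float8_e8m0fnu])  # f8e8m0fnu
    print(dtype_abbrs[torch.bfloat16])        # bf16
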