
Commit

Update
[ghstack-poisoned]
jansel committed Jun 19, 2024
1 parent cafbd58 commit f9d3203
Showing 3 changed files with 6 additions and 15 deletions.
12 changes: 1 addition & 11 deletions test/inductor/test_torchinductor.py
@@ -1197,7 +1197,6 @@ def fn(a):
self.common(fn, [torch.linspace(-10, 10, 41)])

@skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80")
- @skip_if_halide # bf16
def test_scatter_bf16(self):
def fn(inp, src, index):
return inp.scatter_add(0, index, src)
@@ -1921,9 +1920,7 @@ def fn(a, b):
)

dtypes = [torch.float, torch.float16]
- if not (self.device == "cuda" and not SM80OrLater) and not is_halide_backend(
-     self.device
- ):
+ if not (self.device == "cuda" and not SM80OrLater):
dtypes += [torch.bfloat16]
for dtype in dtypes:
self.common(fn, (torch.randn(8, 8).to(dtype), torch.randn(8, 8).to(dtype)))
@@ -2052,7 +2049,6 @@ def fn(a, b):
self.common(fn, (torch.randn(4, 4), torch.randn(4, 4)))

@skipCUDAIf(not SM80OrLater, "Requires sm80")
- @skip_if_halide # bf16
def test_dist_bf16(self):
def fn(a, b):
return torch.dist(a.to(torch.bfloat16), b.to(torch.bfloat16))
@@ -4011,8 +4007,6 @@ def forward(self, x):

mod = Model().to(self.device)
for dtype in [torch.half, torch.bfloat16]:
- if dtype == torch.bfloat16 and is_halide_backend(self.device):
-     continue
x = torch.randn(4, 3, 7, 7, device=self.device).to(dtype=dtype)
opt_mod = torch.compile(mod)
res = opt_mod(x)
@@ -5243,7 +5237,6 @@ def fn(x1, x2, x3, x4):
# Constant folding was explicitly turned off due to issue #108388
# Turn it back on for test
@torch._inductor.config.patch(joint_graph_constant_folding=True)
- @skip_if_halide # bf16
def test_remove_no_ops(self):
def matmul_with_op(x, y, fn):
return fn(x @ y)
@@ -9422,7 +9415,6 @@ def fn(x):
self.common(fn, (torch.ones(1, 1, 13, dtype=dtype),))

@unittest.skipIf(not HAS_CPU, "requires C++ compiler")
- @skip_if_halide # bf16
def test_data_type_propogation(self):
from torch._dynamo.utils import detect_fake_mode
from torch._inductor.codegen.common import boolean_ops
@@ -10352,7 +10344,6 @@ def fn(x):
self.assertEqual(ref, actual)

@skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80")
- @skip_if_halide # bf16
def test_bfloat16_to_int16(self):
def fn(a, b):
x = a + b
@@ -11423,7 +11414,6 @@ def run_with_backward():
torch.cuda.is_available() and torch.cuda.get_device_capability() < (9, 0),
"Triton does not support fp8 on A100",
)
- @skip_if_halide # bf16
def test_red_followed_by_transposed_pointwise(self):
bs = 26624
dim = 1024
3 changes: 1 addition & 2 deletions torch/_inductor/codegen/halide.py
@@ -203,6 +203,7 @@ def _print_RoundDecimal(self, expr):

_halide_type = {
torch.bool: "hl.Bool()",
+ torch.bfloat16: "hl.BFloat(16)",
torch.float16: "hl.Float(16)",
torch.float32: "hl.Float(32)",
torch.float64: "hl.Float(64)",
@@ -218,8 +219,6 @@ def _print_RoundDecimal(self, expr):


def halide_type(dtype):
- if dtype == torch.bfloat16:
-     raise Unsupported("torch.bfloat16")
return _halide_type[dtype]


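For reference, a minimal standalone sketch of the lookup after this change. The table and function below are illustrative copies of the diff, not the module's actual objects; only a few dtype entries are reproduced.

import torch

# Illustrative copy of the _halide_type table from the diff (abbreviated):
# torch.bfloat16 now maps to a Halide type string instead of being rejected.
_halide_type = {
    torch.bool: "hl.Bool()",
    torch.bfloat16: "hl.BFloat(16)",  # newly added mapping
    torch.float16: "hl.Float(16)",
    torch.float32: "hl.Float(32)",
    torch.float64: "hl.Float(64)",
}


def halide_type(dtype):
    # No bfloat16 special case anymore; the dtype is resolved through the
    # table like any other supported dtype.
    return _halide_type[dtype]


assert halide_type(torch.bfloat16) == "hl.BFloat(16)"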
6 changes: 4 additions & 2 deletions torch/_inductor/runtime/hints.py
@@ -139,13 +139,15 @@ class HalideInputSpec(typing.NamedTuple):
alias_of: Optional[str] = None

def bindings_type(self):
- if self.ctype == "half*":
-     return "void*" # half not defined
+ if self.ctype in ("half*", "bfloat16*"):
+     return "uint16_t*" # half not defined
return self.ctype

def halide_type(self):
if self.ctype == "half*":
return "halide_type_t(halide_type_float, 16)" # half not defined
+ if self.ctype == "bfloat16*":
+     return "halide_type_t(halide_type_bfloat, 16)" # half not defined
return f"halide_type_of<{self.ctype.replace('*', '')}>()"

def is_scalar(self):
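A minimal standalone sketch of the ctype dispatch added above, with free functions standing in for the HalideInputSpec methods (an illustration of the logic in the diff, not the actual class):

def bindings_type(ctype: str) -> str:
    # Neither half nor bfloat16 exists as a plain C type, so both buffer
    # kinds are passed across the ctypes boundary as raw uint16_t storage.
    if ctype in ("half*", "bfloat16*"):
        return "uint16_t*"
    return ctype


def halide_type(ctype: str) -> str:
    # The Halide runtime still needs the real element type, so half and
    # bfloat16 get distinct halide_type_t descriptors.
    if ctype == "half*":
        return "halide_type_t(halide_type_float, 16)"
    if ctype == "bfloat16*":
        return "halide_type_t(halide_type_bfloat, 16)"
    return f"halide_type_of<{ctype.replace('*', '')}>()"


assert bindings_type("bfloat16*") == "uint16_t*"
assert halide_type("bfloat16*") == "halide_type_t(halide_type_bfloat, 16)"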

