Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mangle signed and unsigned integer types differently #1340

Merged
merged 1 commit into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,43 @@ def test_floordiv(dtype_x, dtype_y, device='cuda'):
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)


def test_unsigned_name_mangling(device='cuda'):
# Test that uint32 and int32 are mangled differently by the compiler
SIZE = 128
# define the kernel / launch-grid

@triton.jit
def kernel(O1, O2, X, Y, SIZE: tl.constexpr):
off = tl.arange(0, SIZE)
x = tl.load(X + off)
y = tl.load(Y + off)
out1 = tl.abs(x) # uint32 -> nop
out2 = tl.abs(-y) # int32 -> should have an effect
tl.store(O1 + off, out1)
tl.store(O2 + off, out2)

dtype_x = 'uint32'
dtype_y = 'int32'
# inputs
rs = RandomState(17)
x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs)
# reference result
expect = (np.abs(x), np.abs(-y))
# triton result
x_tri = to_triton(x, device=device, dst_type=dtype_x)
y_tri = to_triton(y, device=device, dst_type=dtype_y)
actual = tuple(
to_triton(np.empty_like(e), device=device)
for e in expect
)
kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4)

# Bitwise op, so expect exact equality
assert (expect[0] == to_numpy(actual[0])).all()
assert (expect[1] == to_numpy(actual[1])).all()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On current master this fails with negative numbers in the triton ouput.



# ---------------
# test bitwise ops
# ---------------
Expand Down
4 changes: 3 additions & 1 deletion python/triton/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def mangle_ty(ty):
if ty.is_ptr():
return 'P' + mangle_ty(ty.element_ty)
if ty.is_int():
return 'i' + str(ty.int_bitwidth)
SIGNED = triton.language.dtype.SIGNEDNESS.SIGNED
prefix = 'i' if ty.int_signedness == SIGNED else 'u'
return prefix + str(ty.int_bitwidth)
if ty.is_fp8():
return 'fp8'
if ty.is_fp16():
Expand Down