Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,7 @@ def test_kl_div(self):
),
)
torch_kl_div = torch.nn.KLDivLoss(reduction="batchmean", log_target=False).to(
"cuda"
device=DEVICE
)
self.assertExpectedJournal(
check_example(
Expand Down
6 changes: 3 additions & 3 deletions test/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def test_tile_begin(x: torch.Tensor) -> torch.Tensor:
out[tile_m.begin, tile_n.begin] = 1
return out

x = torch.randn(64, 64, device="cuda")
x = torch.randn(64, 64, device=DEVICE)
config = helion.Config(block_sizes=[16, 16])
test_tile_begin.bind((x,)).to_triton_code(config)
result = test_tile_begin.bind((x,)).compile_config(config)(x)
Expand All @@ -272,7 +272,7 @@ def test_tile_end(x: torch.Tensor) -> torch.Tensor:
out[tile_m.end, tile_n.end] = 1
return out

x = torch.randn(64, 64, device="cuda")
x = torch.randn(64, 64, device=DEVICE)
config = helion.Config(block_sizes=[16, 16])
test_tile_end.bind((x,)).to_triton_code(config)
result = test_tile_end.bind((x,)).compile_config(config)(x)
Expand All @@ -285,7 +285,7 @@ def test_tile_id(x: torch.Tensor) -> torch.Tensor:
out[tile_m.id, tile_n.id] = 1
return out

x = torch.randn(64, 64, device="cuda")
x = torch.randn(64, 64, device=DEVICE)
config = helion.Config(block_sizes=[16, 16])
test_tile_id.bind((x,)).to_triton_code(config)
result = test_tile_id.bind((x,)).compile_config(config)(x)
Expand Down
9 changes: 5 additions & 4 deletions test/test_print_ref_eager_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import helion
from helion import exc
from helion._testing import DEVICE
from helion._testing import TestCase
import helion.language as hl

Expand All @@ -35,8 +36,8 @@ def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
out[tile] = x[tile] + y[tile]
return out

x = torch.randn([512, 512], device="cuda", dtype=torch.float16)
y = torch.randn([512, 512], device="cuda", dtype=torch.float16)
x = torch.randn([512, 512], device=DEVICE, dtype=torch.float16)
y = torch.randn([512, 512], device=DEVICE, dtype=torch.float16)
torch.testing.assert_close(add(x, y), torch.add(x, y))

def test_normal_mode_code_print(self):
Expand All @@ -61,8 +62,8 @@ def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
out[tile] = x[tile] + y[tile]
return out

x = torch.randn([512, 512], device="cuda", dtype=torch.float16)
y = torch.randn([512, 512], device="cuda", dtype=torch.float16)
x = torch.randn([512, 512], device=DEVICE, dtype=torch.float16)
y = torch.randn([512, 512], device=DEVICE, dtype=torch.float16)
torch.testing.assert_close(add(x, y), torch.add(x, y))

self.assertNotEqual(
Expand Down
11 changes: 6 additions & 5 deletions test/test_ref_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import helion
from helion import exc
from helion._testing import DEVICE
from helion._testing import TestCase
from helion._testing import assert_ref_eager_mode
import helion.language as hl
Expand All @@ -32,8 +33,8 @@ def print_intermediate_tensor_kernel(
out[tile_m, tile_n] = sum_val
return out

x = torch.ones([2, 2], device="cuda", dtype=torch.float32) * 10.0
y = torch.ones([2, 2], device="cuda", dtype=torch.float32) * 5.0
x = torch.ones([2, 2], device=DEVICE, dtype=torch.float32) * 10.0
y = torch.ones([2, 2], device=DEVICE, dtype=torch.float32) * 5.0
expected = x + y

# Capture stdout to check print output
Expand Down Expand Up @@ -67,7 +68,7 @@ def incorrect_kernel(x: torch.Tensor) -> torch.Tensor:
pass # noqa: PIE790
return x

x = torch.ones([2, 2], device="cuda", dtype=torch.float32) * math.pi
x = torch.ones([2, 2], device=DEVICE, dtype=torch.float32) * math.pi

# Capture stdout to check print output
captured_output = io.StringIO()
Expand All @@ -89,7 +90,7 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
return out

with assert_ref_eager_mode():
x = torch.randn(128, 128, device="cuda")
x = torch.randn(128, 128, device=DEVICE)
result = kernel(x)
expected = x * 2.0
torch.testing.assert_close(result, expected)
Expand All @@ -107,7 +108,7 @@ def kernel(x: torch.Tensor) -> torch.Tensor:
# Run the kernel to capture the warning message
captured_stderr = io.StringIO()
with contextlib.redirect_stderr(captured_stderr):
x = torch.randn(128, 128, device="cuda")
x = torch.randn(128, 128, device=DEVICE)
kernel(x)

stderr_output = captured_stderr.getvalue()
Expand Down
2 changes: 1 addition & 1 deletion test/test_tensor_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
block_sizes=[16, 16, 16],
indexing="tensor_descriptor",
)
torch.cuda.synchronize()
torch.accelerator.synchronize()
torch.testing.assert_close(result_large, expected, atol=1e-2, rtol=1e-2)
self.assertIn(get_tensor_descriptor_fn_name(), code_large)

Expand Down
Loading
Loading