Merged
7 changes: 7 additions & 0 deletions helion/_testing.py
@@ -51,6 +51,13 @@ def skipIfRocm(reason: str) -> Callable[[Callable], Callable]:
return unittest.skipIf(torch.version.hip is not None, reason) # pyright: ignore[reportAttributeAccessIssue]


def skipIfLowVRAM(
reason: str = "Test requires high VRAM",
) -> Callable[[Callable], Callable]:
"""Skip test if HELION_DEV_LOW_VRAM=1 is set"""
return unittest.skipIf(os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1", reason)


@contextlib.contextmanager
def track_run_ref_calls() -> Generator[list[int], None, None]:
"""Context manager that tracks BoundKernel.run_ref calls.
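For context, a minimal usage sketch of the new helper (the test class and method below are illustrative, not from the diff). Note that unittest.skipIf evaluates its condition when the decorator is applied, so HELION_DEV_LOW_VRAM must be set before the test module is imported:

import unittest

from helion._testing import skipIfLowVRAM


class ExampleVRAMTest(unittest.TestCase):
    # Skip condition is read from the environment at import time.
    @skipIfLowVRAM("Test requires high VRAM")
    def test_large_tensor(self):
        self.assertTrue(True)  # placeholder body


if __name__ == "__main__":
    unittest.main()

Running the suite as HELION_DEV_LOW_VRAM=1 python -m unittest then reports the decorated test as skipped.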
2 changes: 1 addition & 1 deletion test/test_autotuner.py
@@ -164,7 +164,7 @@ def add(a, b):
torch.testing.assert_close(result, sum(args))

def test_autotuner_disabled(self):
@helion.kernel
@helion.kernel(use_default_config=False)
Contributor: Does the env var not override this? 🫨

def add(a, b):
out = torch.empty_like(a)
for tile in hl.tile(out.size()):
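For reference, a minimal sketch of the decorator change this hunk makes, assuming use_default_config=False prevents the kernel from silently falling back to a built-in default config (behavior inferred from the test name; the PR does not spell it out). The kernel body mirrors the one in the hunk:

import torch

import helion
import helion.language as hl


# Sketch: with use_default_config=False, the kernel should require an
# autotuned or explicit config rather than running with a default one.
@helion.kernel(use_default_config=False)
def add(a, b):
    out = torch.empty_like(a)
    for tile in hl.tile(out.size()):
        out[tile] = a[tile] + b[tile]
    return out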
4 changes: 3 additions & 1 deletion test/test_cache.py
@@ -21,7 +21,9 @@ def autotune(self):

class TestCache(RefEagerTestDisabled, TestCase):
def test_basic(self):
@helion.kernel(autotuner_fn=StrictLocalAutotuneCache[BasicSearch])
@helion.kernel(
autotuner_fn=StrictLocalAutotuneCache[BasicSearch], use_default_config=False
)
def add(x, y):
x, y = torch.broadcast_tensors(x, y)
out = torch.empty_like(x)
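A sketch of the cached-autotuner wiring this hunk changes. The import paths and the tail of the kernel body are assumptions; the diff shows only the names StrictLocalAutotuneCache and BasicSearch and the first lines of add:

import torch

import helion
import helion.language as hl

# Assumed import locations; the diff does not show where these come from.
from helion.autotuner import BasicSearch
from helion.autotuner.local_cache import StrictLocalAutotuneCache


@helion.kernel(
    autotuner_fn=StrictLocalAutotuneCache[BasicSearch], use_default_config=False
)
def add(x, y):
    x, y = torch.broadcast_tensors(x, y)
    out = torch.empty_like(x)
    for tile in hl.tile(out.size()):  # loop body assumed, cut off in the hunk
        out[tile] = x[tile] + y[tile]
    return out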
5 changes: 5 additions & 0 deletions test/test_loops.py
@@ -12,6 +12,7 @@
from helion._testing import TestCase
from helion._testing import code_and_output
from helion._testing import import_path
from helion._testing import skipIfLowVRAM
from helion._testing import skipIfRefEager
import helion.language as hl

@@ -53,6 +54,7 @@ def test_pointwise_device_loop(self):
torch.testing.assert_close(result, torch.sigmoid(args[0] + 1))
self.assertExpectedJournal(code)

@skipIfLowVRAM("Test requires high VRAM for [128, 128, 128, 128] tensors")
def test_3d_device_loop0(self):
args = (torch.randn([128, 128, 128, 128], device=DEVICE),)
code, result = code_and_output(
@@ -63,6 +65,7 @@ def test_3d_device_loop0(self):
torch.testing.assert_close(result, torch.sin(args[0]))
self.assertExpectedJournal(code)

@skipIfLowVRAM("Test requires high VRAM for [128, 128, 128, 128] tensors")
def test_3d_device_loop1(self):
args = (torch.randn([128, 128, 128, 128], device=DEVICE),)
code, result = code_and_output(
@@ -74,6 +77,7 @@ def test_3d_device_loop1(self):
torch.testing.assert_close(result, torch.sin(args[0]))
self.assertExpectedJournal(code)

@skipIfLowVRAM("Test requires high VRAM for [128, 128, 128, 128] tensors")
def test_3d_device_loop2(self):
args = (torch.randn([128, 128, 128, 128], device=DEVICE),)
code, result = code_and_output(
@@ -86,6 +90,7 @@ def test_3d_device_loop2(self):
torch.testing.assert_close(result, torch.sin(args[0]))
self.assertExpectedJournal(code)

@skipIfLowVRAM("Test requires high VRAM for [128, 128, 128, 128] tensors")
def test_3d_device_loop3(self):
args = (torch.randn([128, 128, 128, 128], device=DEVICE),)
code, result = code_and_output(
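As a back-of-the-envelope check on why these four tests are gated: a single [128, 128, 128, 128] float32 tensor already occupies 1 GiB, and each test holds at least the input plus an output of the same shape:

# 128**4 elements at 4 bytes each (torch.float32):
numel = 128 ** 4                # 268,435,456 elements
size_gib = numel * 4 / 2 ** 30  # bytes -> GiB
print(size_gib)                 # 1.0 GiB per tensor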
2 changes: 1 addition & 1 deletion test/test_misc.py
@@ -349,7 +349,7 @@ def kernel_default_config(x: torch.Tensor) -> torch.Tensor:
self.assertIn("def", code_default) # Basic sanity check

# Test 3: Kernel with no configs and no default - should raise error
@helion.kernel
@helion.kernel(use_default_config=False)
def kernel_no_config(x: torch.Tensor) -> torch.Tensor:
result = torch.empty_like(x)
for tile in hl.tile(x.shape):
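A sketch of exercising the error path this test covers, assuming the call raises when no config source is available; the diff does not show the exact exception type, so the sketch catches a broad Exception:

import torch

# kernel_no_config is the kernel defined in the hunk above.
x = torch.randn(16, device="cuda")
try:
    kernel_no_config(x)
except Exception as exc:  # exact exception type not shown in the diff
    print(f"raised as expected: {type(exc).__name__}")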