diff --git a/helion/_testing.py b/helion/_testing.py index 8d27d56fd..dda909349 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -429,7 +429,8 @@ def run_example( rtol: Relative tolerance for correctness check (default: 1e-2) atol: Absolute tolerance for correctness check (default: 1e-1) """ - torch.set_float32_matmul_precision("high") + torch.backends.cuda.matmul.fp32_precision = "tf32" + torch.backends.cudnn.conv.fp32_precision = "tf32" # type: ignore[reportAttributeAccessIssue] # Normalize to dict format kernels = kernel_fn if isinstance(kernel_fn, dict) else {kernel_name: kernel_fn}