diff --git a/test/test_ops.py b/test/test_ops.py index e260e86f0f..d73ae536ac 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -5,10 +5,12 @@ from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 import unittest from parameterized import parameterized +import pytest # torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...): # test_faketensor failed with module 'torch' has no attribute '_custom_ops' (scroll up for stack trace) +@pytest.mark.filterwarnings("ignore:create_unbacked_symint is deprecated, please use new_dynamic_size instead:UserWarning") @unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels") class TestOps(TestCase): def _create_tensors_with_iou(self, N, iou_thresh): diff --git a/torchao/ops.py b/torchao/ops.py index 3a25dbf6db..fcc6ae9364 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -1,5 +1,14 @@ import torch from torch import Tensor +from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 + +def register_custom_op(name): + def decorator(func): + if TORCH_VERSION_AFTER_2_4: + return torch.library.register_fake(f"{name}")(func) + else: + return torch.library.impl_abstract(f"{name}")(func) + return decorator def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: """ @@ -9,7 +18,7 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: # Defines the meta kernel / fake kernel / abstract impl -@torch.library.impl_abstract("torchao::nms") +@register_custom_op("torchao::nms") def _(dets, scores, iou_threshold): torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D") torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}") @@ -36,7 +45,7 @@ def prepack_fp6_weight(fp6_weight: Tensor) -> Tensor: return torch.ops.torchao.prepack_fp6_weight.default(fp6_weight) -@torch.library.impl_abstract("torchao::prepack_fp6_weight") +@register_custom_op("torchao::prepack_fp6_weight") def _(fp6_weight): torch._check(fp6_weight.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp6_weight.dim()}D") return torch.empty_like(fp6_weight) @@ -49,7 +58,7 @@ def fp16_to_fp6(fp16_tensor: Tensor) -> Tensor: return torch.ops.torchao.fp16_to_fp6.default(fp16_tensor) -@torch.library.impl_abstract("torchao::fp16_to_fp6") +@register_custom_op("torchao::fp16_to_fp6") def _(fp16_tensor): torch._check(fp16_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp16_tensor.dim()}D") torch._check(fp16_tensor.dtype is torch.float16, lambda: f"weight must be FP16, got {fp16_tensor.dtype}") @@ -74,7 +83,7 @@ def fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tenso return torch.ops.torchao.fp16act_fp6weight_linear.default(_in_feats, _weights, _scales, splitK) -@torch.library.impl_abstract("torchao::fp16act_fp6weight_linear") +@register_custom_op("torchao::fp16act_fp6weight_linear") def _(_in_feats, _weights, _scales, splitK = 1): torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D") torch._check(_in_feats.dtype is torch.float16, lambda: f"weight must be FP16, got {_in_feats.dtype}") @@ -95,7 +104,7 @@ def fp6_weight_dequant(fp6_tensor: Tensor, fp16_scale: Tensor) -> Tensor: return torch.ops.torchao.fp6_weight_dequant.default(fp6_tensor, fp16_scale) -@torch.library.impl_abstract("torchao::fp6_weight_dequant") +@register_custom_op("torchao::fp6_weight_dequant") def _(fp6_tensor, fp16_scale): torch._check(fp6_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp6_tensor.dim()}D") torch._check(fp6_tensor.dtype is torch.int32, lambda: f"weight must be INT32, got {fp6_tensor.dtype}")