Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
import unittest
from parameterized import parameterized
import pytest


# torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...):
# test_faketensor failed with module 'torch' has no attribute '_custom_ops' (scroll up for stack trace)
@pytest.mark.filterwarnings("ignore:create_unbacked_symint is deprecated, please use new_dynamic_size instead:UserWarning")
@unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels")
class TestOps(TestCase):
def _create_tensors_with_iou(self, N, iou_thresh):
Expand Down
19 changes: 14 additions & 5 deletions torchao/ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
import torch
from torch import Tensor
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4

def register_custom_op(name):
    """Return a decorator registering ``func`` as the fake/meta kernel for op ``name``.

    Chooses the torch.library registration entry point that matches the
    installed torch version: ``register_fake`` (torch >= 2.4) superseded
    the older ``impl_abstract`` API.
    """
    def decorator(func):
        # Resolve the registration function once, then apply it to ``func``.
        if TORCH_VERSION_AFTER_2_4:
            register = torch.library.register_fake
        else:
            register = torch.library.impl_abstract
        return register(f"{name}")(func)
    return decorator

def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
"""
Expand All @@ -9,7 +18,7 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:


# Defines the meta kernel / fake kernel / abstract impl
@torch.library.impl_abstract("torchao::nms")
@register_custom_op("torchao::nms")
def _(dets, scores, iou_threshold):
torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}")
Expand All @@ -36,7 +45,7 @@ def prepack_fp6_weight(fp6_weight: Tensor) -> Tensor:
return torch.ops.torchao.prepack_fp6_weight.default(fp6_weight)


# NOTE(review): the scraped diff showed both the removed decorator
# (torch.library.impl_abstract) and its replacement; the merged file keeps
# only the version-dispatching register_custom_op decorator below.
@register_custom_op("torchao::prepack_fp6_weight")
def _(fp6_weight):
    """Meta/fake kernel for ``torchao::prepack_fp6_weight``.

    Validates that the weight is a 2-D tensor and returns an uninitialized
    tensor with the same shape and dtype — the fake kernel only models the
    output's metadata, not its contents.
    """
    torch._check(fp6_weight.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp6_weight.dim()}D")
    return torch.empty_like(fp6_weight)
Expand All @@ -49,7 +58,7 @@ def fp16_to_fp6(fp16_tensor: Tensor) -> Tensor:
return torch.ops.torchao.fp16_to_fp6.default(fp16_tensor)


@torch.library.impl_abstract("torchao::fp16_to_fp6")
@register_custom_op("torchao::fp16_to_fp6")
def _(fp16_tensor):
torch._check(fp16_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp16_tensor.dim()}D")
torch._check(fp16_tensor.dtype is torch.float16, lambda: f"weight must be FP16, got {fp16_tensor.dtype}")
Expand All @@ -74,7 +83,7 @@ def fp16act_fp6weight_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tenso
return torch.ops.torchao.fp16act_fp6weight_linear.default(_in_feats, _weights, _scales, splitK)


@torch.library.impl_abstract("torchao::fp16act_fp6weight_linear")
@register_custom_op("torchao::fp16act_fp6weight_linear")
def _(_in_feats, _weights, _scales, splitK = 1):
torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D")
torch._check(_in_feats.dtype is torch.float16, lambda: f"weight must be FP16, got {_in_feats.dtype}")
Expand All @@ -95,7 +104,7 @@ def fp6_weight_dequant(fp6_tensor: Tensor, fp16_scale: Tensor) -> Tensor:
return torch.ops.torchao.fp6_weight_dequant.default(fp6_tensor, fp16_scale)


@torch.library.impl_abstract("torchao::fp6_weight_dequant")
@register_custom_op("torchao::fp6_weight_dequant")
def _(fp6_tensor, fp16_scale):
torch._check(fp6_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp6_tensor.dim()}D")
torch._check(fp6_tensor.dtype is torch.int32, lambda: f"weight must be INT32, got {fp6_tensor.dtype}")
Expand Down