51 changes: 51 additions & 0 deletions backends/cadence/aot/ops_registrations.py
@@ -56,6 +56,8 @@ def _validate_ref_impl_exists() -> None:

if op_name_clean not in ref_impls:
if op_name not in _SKIP_OPS:
print("*"*100)
print(op_name_clean)
error_impls.append(op_name)

if error_impls:
@@ -81,6 +83,13 @@ def register_fake(
_REGISTERED_META_KERNELS.add(op_name)
return _register_fake_original(op_name)

lib.define(
"box_with_nms_limit.out(Tensor scores, Tensor boxes, Tensor batch_splits, float score_thresh, float nms, int detections_per_im, bool soft_nms_enabled, str soft_nms_method, float soft_nms_sigma, float soft_nms_min_score_thres, bool rotated, bool cls_agnostic_bbox_reg, bool input_boxes_include_bg_cls, bool output_classes_include_bg_cls, bool legacy_plus_one, Tensor[]? _caffe2_preallocated_outputs=None, *, Tensor(a!) out_scores, Tensor(b!) out_boxes, Tensor(c!) out_classes, Tensor(d!) batch_splits_out, Tensor(e!) out_keeps, Tensor(f!) out_keeps_size) -> (Tensor(a!) scores, Tensor(b!) boxes, Tensor(c!) classes, Tensor(d!) batch_splits, Tensor(e!) keeps, Tensor(f!) keeps_size)"
)

lib.define(
"box_with_nms_limit(Tensor scores, Tensor boxes, Tensor batch_splits, float score_thresh, float nms, int detections_per_im, bool soft_nms_enabled, str soft_nms_method, float soft_nms_sigma, float soft_nms_min_score_thres, bool rotated, bool cls_agnostic_bbox_reg, bool input_boxes_include_bg_cls, bool output_classes_include_bg_cls, bool legacy_plus_one, Tensor[]? _caffe2_preallocated_outputs=None) -> (Tensor scores, Tensor boxes, Tensor classes, Tensor batch_splits, Tensor keeps, Tensor keeps_size)"
)

lib.define(
"quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
@@ -2734,6 +2743,48 @@ def quantized_w8a32_gru_meta(
return hidden.new_empty((2, hidden.shape[-1]), dtype=torch.float32)



@register_fake("cadence::box_with_nms_limit")
def box_with_nms_limit_meta(
tscores: torch.Tensor,
tboxes: torch.Tensor,
tbatch_splits: torch.Tensor,
score_thres: float,
nms_thres: float,
detections_per_im: int,
soft_nms_enabled: bool,
soft_nms_method_str: str,
soft_nms_sigma: float,
soft_nms_min_score_thres: float,
rotated: bool,
cls_agnostic_bbox_reg: bool,
input_boxes_include_bg_cls: bool,
output_classes_include_bg_cls: bool,
legacy_plus_one: bool,
optional_tensor_list: Optional[list[torch.Tensor]] = None,
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]:
box_dim = 5 if rotated else 4
assert detections_per_im != 0
batch_size = tbatch_splits.size(0)
num_classes = tscores.size(1)
out_scores = tscores.new_empty([detections_per_im])
out_boxes = tscores.new_empty([detections_per_im, box_dim])
out_classes = tscores.new_empty([detections_per_im])
batch_splits_out = tscores.new_empty([batch_size])
out_keeps = tscores.new_empty([detections_per_im], dtype=torch.int32)
out_keeps_size = tscores.new_empty([batch_size, num_classes], dtype=torch.int32)

return (
out_scores,
out_boxes,
out_classes,
batch_splits_out,
out_keeps,
out_keeps_size,
)

# Validate that all meta kernels have reference implementations
# This is called at module import time to catch missing implementations early
_validate_ref_impl_exists()
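
For context, here is a minimal shape-check sketch (not part of this diff) that mirrors the logic of the box_with_nms_limit_meta fake kernel registered above; the tensor sizes in the example are assumptions chosen purely for illustration.

import torch

# Sketch only: reproduces the shape logic of box_with_nms_limit_meta so the
# expected output shapes can be checked without registering any ops. The example
# sizes (100 proposals, 80 classes, batch of 1, 10 detections) are assumptions.
def expected_output_shapes(
    scores: torch.Tensor,
    batch_splits: torch.Tensor,
    detections_per_im: int,
    rotated: bool,
) -> dict:
    box_dim = 5 if rotated else 4
    batch_size = batch_splits.size(0)
    num_classes = scores.size(1)
    return {
        "out_scores": [detections_per_im],
        "out_boxes": [detections_per_im, box_dim],
        "out_classes": [detections_per_im],
        "batch_splits_out": [batch_size],
        "out_keeps": [detections_per_im],
        "out_keeps_size": [batch_size, num_classes],
    }

shapes = expected_output_shapes(torch.empty(100, 80), torch.empty(1), 10, rotated=False)
assert shapes["out_boxes"] == [10, 4]
assert shapes["out_keeps_size"] == [1, 80]
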
44 changes: 44 additions & 0 deletions backends/cadence/aot/ref_implementations.py
@@ -15,6 +15,8 @@
from executorch.exir.scalar_type import ScalarType
from torch.library import impl, Library

from typing import Optional

m = Library("cadence", "IMPL", "CompositeExplicitAutograd")
torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib")

@@ -2146,3 +2148,45 @@ def quantized_softmax(
out_scale,
out_zero_point,
)


@impl_tracked(m, "box_with_nms_limit")
def box_with_nms_limit(
tscores: torch.Tensor,
tboxes: torch.Tensor,
tbatch_splits: torch.Tensor,
score_thres: float,
nms_thres: float,
detections_per_im: int,
soft_nms_enabled: bool,
soft_nms_method_str: str,
soft_nms_sigma: float,
soft_nms_min_score_thres: float,
rotated: bool,
cls_agnostic_bbox_reg: bool,
input_boxes_include_bg_cls: bool,
output_classes_include_bg_cls: bool,
legacy_plus_one: bool,
optional_tensor_list: Optional[list[torch.Tensor]] = None,
) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]:

return torch.ops._caffe2.BoxWithNMSLimit(
tscores,
tboxes,
tbatch_splits,
score_thres,
nms_thres,
detections_per_im,
soft_nms_enabled,
soft_nms_method_str,
soft_nms_sigma,
soft_nms_min_score_thres,
rotated,
cls_agnostic_bbox_reg,
input_boxes_include_bg_cls,
output_classes_include_bg_cls,
legacy_plus_one,
optional_tensor_list,
)
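
Since the reference implementation above is a thin wrapper around the existing caffe2 operator, calling the Cadence op in eager mode should match calling torch.ops._caffe2.BoxWithNMSLimit directly. A usage sketch follows; it assumes the _caffe2 custom-ops library is built and loaded, and the tensor shapes and threshold values are illustrative only.

import torch

# Sketch only: illustrative inputs, not taken from the PR. Shapes follow the usual
# caffe2 BoxWithNMSLimit layout assumption: scores are (num_boxes, num_classes) and
# boxes are (num_boxes, num_classes * 4) for axis-aligned boxes.
scores = torch.rand(100, 81)
boxes = torch.rand(100, 81 * 4)
batch_splits = torch.tensor([100.0])

outputs = torch.ops.cadence.box_with_nms_limit(
    scores,
    boxes,
    batch_splits,
    0.05,      # score_thresh
    0.5,       # nms
    100,       # detections_per_im
    False,     # soft_nms_enabled
    "linear",  # soft_nms_method
    0.5,       # soft_nms_sigma
    0.001,     # soft_nms_min_score_thres
    False,     # rotated
    False,     # cls_agnostic_bbox_reg
    True,      # input_boxes_include_bg_cls
    True,      # output_classes_include_bg_cls
    True,      # legacy_plus_one
)
out_scores, out_boxes, out_classes, out_batch_splits, keeps, keeps_size = outputs
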
25 changes: 25 additions & 0 deletions backends/cadence/aot/replace_ops.py
@@ -195,6 +195,30 @@ def call_operator(
)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceCaffe2BoxWithNMSLimitWithCadenceBoxWithNMSLimit(ExportPass):
"""Replaces _caffe2 BoxWithNMSLimit ops with Cadence BoxWithNMSLimit ops.
"""

def call_operator(
self,
op,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
ns = exir_ops.edge if isinstance(op, EdgeOpOverload) else torch.ops
if op != ns._caffe2.BoxWithNMSLimit.default:
return super().call_operator(op, args, kwargs, meta)

return super().call_operator(
exir_ops.edge.cadence.box_with_nms_limit.default,
args,
kwargs,
meta,
)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass):
"""
@@ -2162,6 +2186,7 @@ class CadenceReplaceOpsInGraph:
ReplaceScalarTensorWithFullPass,
ReplaceInfArgInFullWithValuePass,
ReplaceLogicalNotBooleanWhereWithWherePass,
ReplaceCaffe2BoxWithNMSLimitWithCadenceBoxWithNMSLimit,
ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
ReplaceAtenAvgPoolWithCadenceAvgPoolPass,
ReplaceWhereWithFullArgsWithWhereScalar,
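
To round out the picture, a minimal sketch of how the new pass is expected to act on an exported graph: it only swaps the op target, so any _caffe2.BoxWithNMSLimit call node becomes cadence.box_with_nms_limit with identical arguments. The invocation below follows the generic ExportPass calling convention; the graph-module variable is hypothetical.

# Sketch only: applies the pass added above to a hypothetical exported graph module.
from executorch.backends.cadence.aot.replace_ops import (
    ReplaceCaffe2BoxWithNMSLimitWithCadenceBoxWithNMSLimit,
)

def replace_box_with_nms_limit(graph_module):
    # Calling an ExportPass returns a PassResult; the rewritten graph is in
    # .graph_module. Nodes other than _caffe2.BoxWithNMSLimit are left untouched.
    result = ReplaceCaffe2BoxWithNMSLimitWithCadenceBoxWithNMSLimit()(graph_module)
    return result.graph_module
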