Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ def _validate_ref_impl_exists() -> None:
_WARN_ONLY = {
"cadence::quantized_w8a32_linear",
"cadence::quantized_add", # We should only support per_tensor variant, should remove
"cadence::idma_store",
"cadence::idma_load",
"cadence::_softmax_f32_f32",
"cadence::requantize", # We should only support per_tensor variant, should remove
"cadence::quantized_softmax.per_tensor",
Expand All @@ -70,13 +68,11 @@ def _validate_ref_impl_exists() -> None:
"cadence::quantized_relu", # We should only support per_tensor variant, should remove
"cadence::linalg_svd",
"cadence::quantized_conv2d_nhwc", # We should only support per_tensor variant, should remove
"cadence::idma_copy",
"cadence::quantize_per_tensor_asym16u",
"cadence::dequantize_per_tensor_asym8s",
"cadence::quantize_per_tensor_asym16s",
"cadence::dequantize_per_tensor_asym16s",
"cadence::quantized_softmax",
"cadence::idma_wait",
"cadence::quantized_w8a32_gru",
"cadence::quantized_layer_norm", # We should only support per_tensor variant, should remove
}
Expand Down
20 changes: 20 additions & 0 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1636,3 +1636,23 @@ def quantized_embedding_byte(
)

return weight[indices]


@impl_tracked(m, "idma_copy")
def idma_copy(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.Tensor:
    """Reference implementation of the iDMA copy op.

    There is no DMA engine on host, so this reduces to a plain tensor copy.
    ``task_num`` and ``channel`` select hardware resources and are ignored
    in the reference implementation.
    """
    return torch.clone(src)


@impl_tracked(m, "idma_store")
def idma_store(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.Tensor:
    """Reference implementation of the iDMA store op.

    On host there is no DMA hardware to store through, so the op behaves as
    a copy of ``src``. The ``task_num`` and ``channel`` arguments pick
    hardware resources and have no effect here.
    """
    return torch.clone(src)


@impl_tracked(m, "idma_load")
def idma_load(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.Tensor:
    """Reference implementation of the iDMA load op.

    Without DMA hardware a load is just a copy of ``src``; ``task_num``
    and ``channel`` are hardware-selection arguments that the reference
    implementation ignores.
    """
    return torch.clone(src)


@impl_tracked(m, "idma_wait")
def idma_wait(src: torch.Tensor, task_num: int = 0, channel: int = 0) -> torch.Tensor:
    """Reference implementation of the iDMA wait op.

    On host there is no asynchronous transfer to wait on, so the op simply
    yields a copy of ``src``. ``task_num`` and ``channel`` are ignored in
    the reference implementation.
    """
    return torch.clone(src)
Loading