16 changes: 16 additions & 0 deletions backends/cadence/aot/ops_registrations.py
@@ -448,6 +448,12 @@
"roi_align_box_processor(Tensor rois, int output_size_h, int output_size_w, "
"int sampling_ratio, bool aligned) -> (Tensor out)"
)
lib.define(
"_softmax_f32_f32(Tensor self, int dim, bool? half_to_float) -> (Tensor out)"
)
lib.define(
"_softmax_f32_f32.out(Tensor self, int dim, bool? half_to_float, *, Tensor(a!) out) -> Tensor(a!)"
)

# Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
aten_lib = Library("aten", "FRAGMENT")
@@ -2075,3 +2081,13 @@ def roi_align_box_processor_meta(
    aligned: bool,
) -> torch.Tensor:
    return rois.new_empty((rois.shape[0], 80), dtype=torch.uint8)


@register_fake("cadence::_softmax_f32_f32")
def softmax_f32_f32_meta(
    self: torch.Tensor,
    dim: int,
    half_to_float: Optional[bool] = None,
) -> torch.Tensor:
    return self.new_empty(self.size(), dtype=self.dtype)
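For context, a minimal sketch (not part of this diff) of how the schema and fake kernel registered above could be exercised. It assumes the Cadence ops_registrations module has already been imported so the cadence::_softmax_f32_f32 definitions are in place:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Only a fake (meta) kernel is registered here, so the op can be traced for
# export but not executed eagerly; exercise it under FakeTensorMode instead.
with FakeTensorMode():
    x = torch.randn(2, 8, dtype=torch.float32)
    y = torch.ops.cadence._softmax_f32_f32(x, 1, False)
    assert y.shape == x.shape and y.dtype == torch.float32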
7 changes: 7 additions & 0 deletions backends/cadence/aot/type_dispatch.py
@@ -93,6 +93,13 @@ class CompileTimeTypeDispatchPass(ExportPass):
            },
            weight_arg_idx=3,
        ),
        exir_ops.edge.aten._softmax.default: OpConfig(
            "_softmax",
            type_dispatch_suffixes={
                (torch.float32,): "f32_f32",
            },
            variant="default",
        ),
    }

    def call_operator(
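To show the intent of the new dispatch entry, here is a hypothetical illustration (not code from this PR) of the typed op name the pass is expected to resolve to; the actual name construction happens inside CompileTimeTypeDispatchPass.call_operator, which is outside this hunk:

import torch

# Hypothetical reconstruction of the typed-op name from the OpConfig fields:
# base name "_softmax", a suffix chosen by the input dtype, and the "default" variant.
base_name = "_softmax"
suffix = {(torch.float32,): "f32_f32"}[(torch.float32,)]
variant = "default"

typed_op = f"cadence::{base_name}_{suffix}.{variant}"
assert typed_op == "cadence::_softmax_f32_f32.default"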
158 changes: 158 additions & 0 deletions backends/cadence/hifi/operators/op_softmax_f32_f32.cpp
@@ -0,0 +1,158 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/runtime/kernel/kernel_includes.h>

using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using executorch::runtime::KernelRuntimeContext;
using torch::executor::Error;

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {

inline Tensor& _softmax_f32_f32_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    int64_t dim,
    ::executorch::aten::optional<bool> half_to_float,
    Tensor& out) {
  constexpr int kNnlibMaxDim = 16;

  const std::optional<int64_t>& dim_t = dim;
  const size_t d = ET_NORMALIZE_IX(dim_t.value(), in.dim());
  const size_t size = in.size(d);

  size_t stride = 1, outer_size = 1;

  size_t outer_stride = 1;

  int* p_inp = (int*)in.const_data_ptr<float>();
  int* out_data = (int*)out.mutable_data_ptr<float>();

  int num_inp_dims = in.dim();
  int num_out_dims = num_inp_dims;

  int p_inp_shape[kNnlibMaxDim];
  int p_out_shape[kNnlibMaxDim];
  int p_permute_vec[kNnlibMaxDim];

  for (int i = 0; i < num_inp_dims; i++)
    p_inp_shape[i] = in.size(i);
  // Build a permutation that swaps the softmax dim `d` with the last dim so
  // the reduction runs over contiguous memory; also accumulate outer_size.
  for (int i = 0; i < num_inp_dims; i++) {
    if (i == d)
      p_permute_vec[i] = num_inp_dims - 1;
    else if (i == (num_inp_dims - 1))
      p_permute_vec[num_inp_dims - 1] = d;
    else
      p_permute_vec[i] = i;

    p_out_shape[i] = p_inp_shape[p_permute_vec[i]];

    if (i != d)
      outer_size = outer_size * p_inp_shape[i];
  }

  outer_stride = size;

  WORD32 ret_val = 0;

  // Check if the input is permuted. If not, then we don't need to transpose
  bool is_permuted = false;
  for (int i = 0; i < num_inp_dims; i++) {
    if (p_permute_vec[i] != i) {
      is_permuted = true;
      break;
    }
  }

  if (!is_permuted) {
    const float* p_inpf = in.const_data_ptr<float>();
    float* out_dataf = out.mutable_data_ptr<float>();

    for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
      size_t outer = outer_idx * outer_stride;
      for (size_t inner_idx = 0; inner_idx < stride; ++inner_idx) {
        size_t base = outer + inner_idx;

        float* p_in_data = (float*)&p_inpf[base];
        float* p_out_data = (float*)&out_dataf[base];

        ret_val = xa_nn_vec_softmax_f32_f32(p_out_data, p_in_data, size);

        ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
      }
    }
    return out;
  }

  int* p_out =
      (int*)kernels::allocate_temp_memory(ctx, out.numel() * sizeof(int));

  ET_KERNEL_CHECK(ctx, p_out != nullptr, MemoryAllocationFailed, out);

  int* p_out1 =
      (int*)kernels::allocate_temp_memory(ctx, out.numel() * sizeof(int));

  ET_KERNEL_CHECK(ctx, p_out1 != nullptr, MemoryAllocationFailed, out);

  // Transpose the input so the softmax dim becomes the innermost dim.
  ret_val = xa_nn_transpose_32_32(
      p_out,
      p_out_shape,
      p_inp,
      p_inp_shape,
      p_permute_vec,
      num_out_dims,
      num_inp_dims);

  ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);

  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
    size_t outer = outer_idx * outer_stride;
    for (size_t inner_idx = 0; inner_idx < stride; ++inner_idx) {
      size_t base = outer + inner_idx;

      float* p_in_data = (float*)&p_out[base];
      float* p_out_data = (float*)&p_out1[base];

      ret_val = xa_nn_vec_softmax_f32_f32(p_out_data, p_in_data, size);

      ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
    }
  }

  // Transpose the result back to the original layout.
  ret_val = xa_nn_transpose_32_32(
      out_data,
      p_inp_shape,
      p_out1,
      p_out_shape,
      p_permute_vec,
      num_out_dims,
      num_inp_dims);

  ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);

  return out;
}

Tensor& softmax_f32_f32_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    int64_t dim,
    ::executorch::aten::optional<bool> half_to_float,
    Tensor& out) {
  return _softmax_f32_f32_out(ctx, in, dim, half_to_float, out);
}

} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
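As a readability aid, the kernel's data movement can be modeled in a few lines of PyTorch (an illustrative sketch, not part of the PR): permute the softmax dim to the innermost position, run a row-wise softmax over contiguous rows, then permute back, which is the same strategy the kernel implements with xa_nn_transpose_32_32 and xa_nn_vec_softmax_f32_f32.

import torch

def softmax_via_permute(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Swap `dim` with the last axis (the swap is its own inverse), softmax over
    # the now-contiguous rows, then swap back to the original layout.
    d = dim % x.dim()
    perm = list(range(x.dim()))
    perm[d], perm[-1] = perm[-1], perm[d]
    xt = x.permute(perm).contiguous()
    rows = torch.softmax(xt.reshape(-1, xt.shape[-1]), dim=-1)
    return rows.reshape(xt.shape).permute(perm).contiguous()

x = torch.randn(3, 5, 7)
assert torch.allclose(softmax_via_permute(x, 1), torch.softmax(x, dim=1))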
1 change: 1 addition & 0 deletions backends/cadence/hifi/operators/targets.bzl
@@ -97,6 +97,7 @@ OPERATORS = [
"sigmoid",
"slice_copy",
"softmax",
"softmax_f32_f32",
"split_with_sizes_copy",
"sub",
"tanh",